synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +104 -6
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,364 @@
|
|
1
|
+
"""CrafterClassicEnvironment — thin wrapper exposing CrafterEngine via StatefulEnvironment API."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import List, Optional, Any, Dict, Union
|
6
|
+
import dataclasses
|
7
|
+
import logging
|
8
|
+
import time
|
9
|
+
|
10
|
+
# Import logging configuration to suppress JAX debug messages
|
11
|
+
from .config_logging import safe_compare
|
12
|
+
|
13
|
+
# Import tracing abstractions
|
14
|
+
from synth_ai.tracing_v3.abstractions import (
|
15
|
+
RuntimeEvent,
|
16
|
+
SessionEventMarkovBlanketMessage,
|
17
|
+
TimeRecord,
|
18
|
+
)
|
19
|
+
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
from synth_ai.environments.examples.crafter_classic.engine import (
|
23
|
+
CrafterEngine,
|
24
|
+
CrafterPrivateState,
|
25
|
+
CrafterPublicState,
|
26
|
+
CrafterEngineSnapshot,
|
27
|
+
)
|
28
|
+
from synth_ai.environments.examples.crafter_classic.taskset import CrafterTaskInstance
|
29
|
+
from synth_ai.environments.environment.shared_engine import (
|
30
|
+
GetObservationCallable,
|
31
|
+
InternalObservation,
|
32
|
+
)
|
33
|
+
from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
34
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
35
|
+
from synth_ai.environments.environment.tools import (
|
36
|
+
AbstractTool,
|
37
|
+
EnvToolCall,
|
38
|
+
ToolResult,
|
39
|
+
TOOL_REGISTRY,
|
40
|
+
register_tool,
|
41
|
+
)
|
42
|
+
from pydantic import BaseModel, Field
|
43
|
+
|
44
|
+
|
45
|
+
# --- Tool Definition ---
|
46
|
+
class CrafterActionInput(BaseModel):
    # Argument schema for the "interact" tool: one required integer action id.
    # (A `#` comment is used instead of a docstring so the generated JSON
    # schema description is unchanged.)
    action: int = Field(default=..., description="Integer action for the Crafter environment.")
|
48
|
+
|
49
|
+
|
50
|
+
class CrafterInteractTool(AbstractTool):
    """``interact`` tool: validate an integer action and step a CrafterEngine.

    A successful call returns ``ToolResult(ok=True)`` whose payload holds the
    engine's ``private_state`` and ``public_state``. A failing call returns
    ``ok=False`` with ``str(error)`` and the states re-read from the engine.
    When ``session_tracer`` has an active ``current_session``, a RuntimeEvent
    describing the executed action is added to it.
    """

    name = "interact"
    description = "Performs an action in the Crafter environment."
    call_schema = CrafterActionInput
    result_schema = ToolResult

    def __init__(self, engine: CrafterEngine, session_tracer: Optional[Any] = None):
        self.engine = engine
        self.session_tracer = session_tracer

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Validate ``call.args``, apply the action, and return the new states."""
        # Only argument validation and the engine step belong in this try.
        # Tracing happens afterwards, so a tracer failure can no longer turn an
        # action that was already applied into an ok=False result.
        try:
            state_before = {"action_args": call.args}
            validated_args = self.call_schema(**call.args)
            action_to_pass = self.engine._validate_action_engine(validated_args.action)
            priv_state, pub_state = await self.engine._step_engine(action_to_pass)
        except Exception as e:
            logger.debug("Crafter interact tool call failed: %s", e, exc_info=True)
            return self._error_result(e)

        state_after = {
            "engine_result": {"private_state": priv_state, "public_state": pub_state}
        }
        self._record_execution_event(state_before, state_after, action_to_pass)

        return ToolResult(
            ok=True,
            payload={
                "public_state": pub_state,
                "private_state": priv_state,
            },
        )

    def _error_result(self, error: Exception) -> ToolResult:
        """Build an ok=False result carrying the error and the engine's current states."""
        pub_state_on_error = self.engine._get_public_state_from_env()  # Use engine helper
        # Recompute the terminal flags from the live env: health <= 0 means the
        # player died; step >= length means the episode ran out of steps.
        health_dead = safe_compare(0, self.engine.env._player.health, ">=")
        step_exceeded = safe_compare(self.engine.env._length, self.engine.env._step, "<=")
        priv_state_on_error = self.engine._get_private_state_from_env(
            0, health_dead, step_exceeded
        )
        return ToolResult(
            ok=False,
            error=str(error),
            payload={
                "public_state": pub_state_on_error,
                "private_state": priv_state_on_error,
            },
        )

    def _record_execution_event(
        self, state_before: Dict[str, Any], state_after: Dict[str, Any], action: Any
    ) -> None:
        """Best-effort: add a RuntimeEvent for the executed action to the tracer session."""
        tracer = self.session_tracer
        if not (tracer and getattr(tracer, "current_session", None)):
            return
        try:
            runtime_execution_event = RuntimeEvent()
            runtime_execution_event.time_record = TimeRecord()
            runtime_execution_event.time_record.event_time = time.time()
            runtime_execution_event.time_record.message_time = None
            runtime_execution_event.system_instance_id = "crafter_interact_tool"
            runtime_execution_event.system_state_before = state_before
            runtime_execution_event.system_state_after = state_after
            runtime_execution_event.actions = [action]
            runtime_execution_event.metadata = {"execution_step": "engine_action"}
            # Add directly to event history, bypassing timestep requirement
            tracer.current_session.add_event(runtime_execution_event)
        except Exception:
            # Tracing is telemetry; never let it fail a completed action.
            logger.warning("Failed to record crafter_interact_tool runtime event", exc_info=True)
|
117
|
+
|
118
|
+
|
119
|
+
# Default observation callable; CrafterClassicEnvironment uses it unless custom_step_obs / custom_ckpt_obs is passed to its __init__.
|
120
|
+
class SynthCrafterObservationCallable(GetObservationCallable):
    """Build a dict observation from Crafter public state plus episode info.

    The result contains every public-state field (via ``dataclasses.asdict``),
    the private reward/termination fields, and ``tool_error`` whenever the
    public state carries a truthy ``error_info``.
    """

    async def get_observation(
        self, pub: CrafterPublicState, priv: CrafterPrivateState
    ) -> InternalObservation:
        observation: Dict[str, Any] = {
            **dataclasses.asdict(pub),  # type: ignore
            "reward_last_step": priv.reward_last_step,
            "total_reward_episode": priv.total_reward_episode,
            "terminated": priv.terminated,
            "truncated": priv.truncated,
        }
        error_info = pub.error_info
        if error_info:
            observation["tool_error"] = error_info
        return observation
|
134
|
+
|
135
|
+
|
136
|
+
class CrafterClassicEnvironment(StatefulEnvironment, ReproducibleEnvironment[CrafterEngine]):
    """Environment wrapper bridging agent tool‑calls to `crafter.Env` dynamics.

    Each ``step`` accepts a single ``interact`` tool call and forwards it to a
    ``CrafterInteractTool`` bound to this environment's ``CrafterEngine``. It
    then turns the resulting private/public states into an observation.
    When ``session_tracer`` has an active ``current_session``, RuntimeEvents
    are recorded for tool-call validation and observation generation.
    """

    def __init__(
        self,
        task_instance: "CrafterTaskInstance",
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
        session_tracer: Optional[Any] = None,  # SessionTracer from higher level
    ) -> None:
        self.name = "CrafterClassic"
        self.task_instance = task_instance
        self.custom_step_observation_callable = custom_step_obs or SynthCrafterObservationCallable()
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or SynthCrafterObservationCallable()
        )
        self.engine = CrafterEngine(task_instance)
        self.session_tracer = session_tracer  # Store tracer for runtime events

        self._interact_tool = CrafterInteractTool(self.engine, session_tracer=session_tracer)
        # Only the first tool named "interact" is registered globally; step()
        # always calls this environment's own self._interact_tool.
        if self._interact_tool.name not in TOOL_REGISTRY:
            register_tool(self._interact_tool)

    # ────────────────────────────────────────────────────────────────────
    # Lifecycle helpers
    # ────────────────────────────────────────────────────────────────────

    async def initialize(self, seed: Optional[int] = None) -> InternalObservation:  # type: ignore[override]
        """Reset the engine and return the initial observation.

        Args:
            seed: Explicit world seed. When ``None``, ``task_instance.metadata.seed``
                is used. If that is missing or ``None``, the ``"seed"`` key of a
                dict ``task_instance.initial_engine_snapshot`` is tried.
        """
        if seed is None:
            metadata = getattr(self.task_instance, "metadata", None)
            seed = getattr(metadata, "seed", None)
        # Fall back to the snapshot even when metadata exists but its seed is None.
        if seed is None:
            snapshot = getattr(self.task_instance, "initial_engine_snapshot", None)
            if isinstance(snapshot, dict):
                seed = snapshot.get("seed")

        priv, pub = await self.engine._reset_engine(seed=seed)
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def terminate(self) -> InternalObservation:  # type: ignore[override]
        """Return an observation flagged ``terminated`` with a status message."""
        pub = self.engine._get_public_state_from_env()
        priv = self.engine._get_private_state_from_env(0, True, False)  # Terminated state
        priv.terminated = True
        obs_dict = {"status": "Environment terminated."}
        return await self._to_observation(
            priv, pub, self.custom_step_observation_callable, extra_obs=obs_dict
        )

    # ────────────────────────────────────────────────────────────────────
    # Step + checkpoint
    # ────────────────────────────────────────────────────────────────────

    def validate_tool_calls(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> EnvToolCall:
        """Normalize ``tool_calls`` to a single ``interact`` EnvToolCall.

        Accepts a bare call, a list (its first element is used), or a list of
        lists (the first element of the first inner list is used).

        Raises:
            ValueError: on an empty outer or inner list, or a tool other than
                ``"interact"``.
            TypeError: if the input, or the element selected from it, is not an
                EnvToolCall.
        """
        # Store the original tool calls for tracing
        state_before = {"tool_calls": tool_calls}

        # Normalize and validate to a single EnvToolCall (same as Sokoban)
        if isinstance(tool_calls, list):
            if not tool_calls:
                raise ValueError("Received empty list of tool calls.")
            if isinstance(tool_calls[0], list):
                if not tool_calls[0]:
                    raise ValueError("Received empty inner list of tool calls.")
                agent_call = tool_calls[0][0]
            else:
                agent_call = tool_calls[0]
        elif isinstance(tool_calls, EnvToolCall):
            agent_call = tool_calls
        else:
            raise TypeError(f"Unexpected type for tool_calls: {type(tool_calls)}")

        if not isinstance(agent_call, EnvToolCall):
            raise TypeError(f"Processed call is not EnvToolCall: {type(agent_call)}")
        if agent_call.tool != "interact":
            raise ValueError(f"Unknown tool: {agent_call.tool}. Expected 'interact'.")

        self._record_runtime_event(
            "crafter_environment",
            state_before,
            {"validated_call": agent_call},
            {"validation_step": "tool_call_validation"},
        )
        return agent_call

    async def step(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> InternalObservation:  # type: ignore[override]
        """Apply one ``interact`` call and return the resulting observation.

        If the tool result lacks valid state objects, the current engine
        states are used instead. On tool failure, the error message is exposed
        through ``pub_state.error_info``.
        """
        step_start_time = time.time()
        agent_call = self.validate_tool_calls(tool_calls)
        interact_start = time.time()
        tool_result: ToolResult = await self._interact_tool(agent_call)
        interact_time = time.time() - interact_start

        # Guard against a missing payload instead of crashing on .get().
        payload_dict = tool_result.payload or {}
        priv_state = payload_dict.get("private_state")
        pub_state = payload_dict.get("public_state")
        states_valid = isinstance(priv_state, CrafterPrivateState) and isinstance(
            pub_state, CrafterPublicState
        )

        if tool_result.ok:
            if not states_valid:
                logger.error(
                    "Invalid state types in payload: priv=%s, pub=%s",
                    type(priv_state),
                    type(pub_state),
                )
                priv_state, pub_state = self._current_states()
                pub_state.error_info = "Invalid state types in tool result"
        else:
            # Tool call failed: prefer the payload states, else read current ones.
            if not states_valid:
                priv_state, pub_state = self._current_states()
            if tool_result.error:
                pub_state.error_info = tool_result.error

        obs = await self._to_observation(
            priv_state, pub_state, self.custom_step_observation_callable
        )
        total_step_time = time.time() - step_start_time
        logger.info(
            "CrafterClassic step completed in %.3fs (interact: %.3fs)",
            total_step_time,
            interact_time,
        )
        return obs

    async def checkpoint(self) -> InternalObservation:  # type: ignore[override]
        """Return a checkpoint observation; dict observations gain ``engine_snapshot_data``."""
        engine_snapshot: CrafterEngineSnapshot = await self.engine._serialize_engine()
        # Report the real terminal flags (same checks as step) rather than
        # always claiming the episode is still running.
        health_dead, step_exceeded = self._terminal_flags()
        priv = self.engine._get_private_state_from_env(0, health_dead, step_exceeded)
        pub = self.engine._get_public_state_from_env()
        obs_data = await self._to_observation(
            priv, pub, self.custom_checkpoint_observation_callable
        )
        if isinstance(obs_data, dict):
            obs_data["engine_snapshot_data"] = engine_snapshot.model_dump()
        return obs_data

    # ────────────────────────────────────────────────────────────────────
    # Helpers
    # ────────────────────────────────────────────────────────────────────

    def _terminal_flags(self) -> tuple:
        """Return ``(health_dead, step_exceeded)`` computed from the live crafter env."""
        env = self.engine.env
        health_dead = safe_compare(0, env._player.health, ">=")
        step_exceeded = safe_compare(env._length, env._step, "<=")
        return health_dead, step_exceeded

    def _current_states(self) -> tuple:
        """Read ``(private_state, public_state)`` directly from the engine (reward 0)."""
        pub_state = self.engine._get_public_state_from_env()
        health_dead, step_exceeded = self._terminal_flags()
        priv_state = self.engine._get_private_state_from_env(0, health_dead, step_exceeded)
        return priv_state, pub_state

    def _record_runtime_event(
        self,
        system_instance_id: str,
        state_before: Dict[str, Any],
        state_after: Dict[str, Any],
        metadata: Dict[str, Any],
    ) -> None:
        """Best-effort: add a RuntimeEvent to the tracer's current session, if any."""
        tracer = self.session_tracer
        if not (tracer and getattr(tracer, "current_session", None)):
            return
        try:
            event = RuntimeEvent()
            event.time_record = TimeRecord()
            event.time_record.event_time = time.time()
            event.time_record.message_time = None
            event.system_instance_id = system_instance_id
            event.system_state_before = state_before
            event.system_state_after = state_after
            event.metadata = metadata
            # Add directly to event history, bypassing timestep requirement
            tracer.current_session.add_event(event)
        except Exception:
            # Tracing must not abort a step whose action was already applied.
            logger.warning(
                "Failed to record runtime event for %s", system_instance_id, exc_info=True
            )

    async def _to_observation(
        self,
        priv: CrafterPrivateState,
        pub: CrafterPublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Convert states to an observation via ``obs_cb`` (default callable if None).

        ``extra_obs`` is merged in only when the observation is a dict.
        """
        # Store state before observation generation
        state_before = {"private_state": priv, "public_state": pub}

        active_obs_cb = obs_cb or SynthCrafterObservationCallable()
        observation = await active_obs_cb.get_observation(pub, priv)
        if extra_obs and isinstance(observation, dict):
            observation.update(extra_obs)

        self._record_runtime_event(
            "observation_generator",
            state_before,
            {"observation": observation},
            {"observation_step": "state_to_obs_conversion"},
        )
        return observation

    # ────────────────────────────────────────────────────────────────────
    # ReproducibleEnvironment plumbing
    # ────────────────────────────────────────────────────────────────────

    async def _serialize_engine(self) -> CrafterEngineSnapshot:
        """Delegate snapshotting to the underlying CrafterEngine."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls,
        snapshot: CrafterEngineSnapshot,
        task_instance: "CrafterTaskInstance",
        session_tracer: Optional[Any] = None,
    ) -> "CrafterClassicEnvironment":
        """Rebuild an environment around an engine restored from ``snapshot``.

        ``session_tracer`` (optional, default None) is attached to the new
        environment and its interact tool so tracing can continue after a restore.
        """
        eng = await CrafterEngine._deserialize_engine(snapshot, task_instance)
        env = cls(task_instance, session_tracer=session_tracer)
        env.engine = eng
        # CRITICAL: the interact tool was bound to the engine built in __init__;
        # repoint it at the restored engine.
        env._interact_tool.engine = eng
        return env
|
@@ -0,0 +1,233 @@
|
|
1
|
+
"""Procedural Crafter taskset generation with seed filtering by world traits.
|
2
|
+
Run this module as a script to build a TaskInstanceSet of seed-reproducible instances (initial engine snapshots are left as None).
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import random
|
9
|
+
from dataclasses import dataclass, asdict, fields
|
10
|
+
from typing import Dict, List
|
11
|
+
from uuid import UUID, uuid4
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
import crafter
|
15
|
+
from crafter import objects
|
16
|
+
|
17
|
+
from synth_ai.environments.tasks.core import (
|
18
|
+
Impetus,
|
19
|
+
Intent,
|
20
|
+
SplitInfo,
|
21
|
+
Task,
|
22
|
+
TaskInstance,
|
23
|
+
TaskInstanceMetadata,
|
24
|
+
TaskInstanceSet,
|
25
|
+
)
|
26
|
+
|
27
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
28
|
+
# Config
|
29
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
30
|
+
# Module-level Task describing the procedural Crafter seed-generation setting.
TASK = Task(
    global_objectives="Survive and unlock achievements.",
    global_premises="Procedural Crafter seed generation",
    global_constraints="",
    shared_env_params={},
)
|
36
|
+
|
37
|
+
AREA = (64, 64)  # world area passed to crafter.Env
LEN = 10000  # episode length passed to crafter.Env
RADIUS = 10  # Manhattan distance for local trait count
SEED_START = 0  # first seed the generator tries
NUM_INSTANCES = 50  # default number of instances to generate

# Desired trait ranges per difficulty tier.  Order matters: a seed gets the
# FIRST tier whose bounds it satisfies (trees >= min_trees and
# hostiles <= max_hostiles).
TRAIT_BOUNDS = dict(
    easy=dict(min_trees=4, max_hostiles=0),
    medium=dict(min_trees=2, max_hostiles=2),
    hard=dict(min_trees=0, max_hostiles=5),
)
|
58
|
+
|
59
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
60
|
+
# Metadata + instance helpers
|
61
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
62
|
+
|
63
|
+
|
64
|
+
from typing import Optional
|
65
|
+
|
66
|
+
|
67
|
+
@dataclass
class CrafterTaskInstanceMetadata(TaskInstanceMetadata):
    """Metadata recorded for each procedurally generated Crafter instance.

    ``seed`` is the seed the world was generated with. The ``num_*_radius``
    fields are the counts that ``world_traits`` returned for that world at
    generation time.
    """

    difficulty: str  # tier name assigned from TRAIT_BOUNDS ("easy"/"medium"/"hard")
    seed: int  # seed passed to crafter.Env when the instance was generated
    num_trees_radius: int  # world_traits(...)["trees"]
    num_cows_radius: int  # world_traits(...)["cows"]
    num_hostiles_radius: int  # world_traits(...)["hostiles"] (zombies + skeletons)
    world_config: Optional[str] = "normal"  # 'easy', 'normal', 'hard', 'peaceful'
    world_config_path: Optional[str] = None  # Path to custom JSON config
|
76
|
+
|
77
|
+
|
78
|
+
@dataclass
class CrafterTaskInstance(TaskInstance):
    """TaskInstance for Crafter with plain-dict (de)serialization helpers."""

    async def serialize(self) -> dict:  # identical to Sokoban pattern
        """Return a dict representation of this instance.

        A UUID ``id`` is converted to ``str``. ``intent.deterministic_eval_functions``
        is replaced with an empty list (presumably because its entries are not
        serializable — TODO confirm).
        """
        data = asdict(self)
        if isinstance(data.get("id"), UUID):
            data["id"] = str(data["id"])
        if data.get("intent") is not None:
            data["intent"]["deterministic_eval_functions"] = []
        return data

    @classmethod
    async def deserialize(cls, data: dict) -> "CrafterTaskInstance":
        """Rebuild an instance from ``serialize()``-style data.

        Missing impetus/intent/metadata fields are filled with defaults, and
        unknown top-level and metadata keys are ignored. The caller's dict and
        its nested dicts are NOT modified.
        """
        # Work on a shallow copy; nested dicts are copied before being edited.
        data = dict(data)

        if "id" in data:
            try:
                data["id"] = UUID(str(data["id"]))
            except (ValueError, TypeError):
                pass  # keep non-UUID ids unchanged

        impetus_data = data.get("impetus")
        if isinstance(impetus_data, dict):
            impetus_kwargs = dict(impetus_data)
            # Ensure instructions field exists with default if missing
            impetus_kwargs.setdefault("instructions", "Survive and unlock achievements")
            data["impetus"] = Impetus(**impetus_kwargs)

        intent_data = data.get("intent")
        if isinstance(intent_data, dict):
            intent_kwargs = dict(intent_data)
            # Ensure required fields exist with defaults if missing
            intent_kwargs.setdefault("rubric", {"goal": "Unlock achievements"})
            intent_kwargs.setdefault("gold_trajectories", None)
            intent_kwargs.setdefault("gold_state_diff", {})
            intent_kwargs["deterministic_eval_functions"] = []
            data["intent"] = Intent(**intent_kwargs)

        metadata_data = data.get("metadata")
        if isinstance(metadata_data, dict):
            metadata_kwargs = dict(metadata_data)
            # Ensure required fields exist with defaults if missing
            metadata_kwargs.setdefault("difficulty", "medium")
            for key in ("seed", "num_trees_radius", "num_cows_radius", "num_hostiles_radius"):
                metadata_kwargs.setdefault(key, 0)
            # Drop keys the metadata dataclass does not accept (mirrors the
            # top-level filtering below) instead of failing with TypeError.
            accepted = {f.name for f in fields(CrafterTaskInstanceMetadata) if f.init}
            data["metadata"] = CrafterTaskInstanceMetadata(
                **{k: v for k, v in metadata_kwargs.items() if k in accepted}
            )

        keep = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in keep})
|
128
|
+
|
129
|
+
|
130
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
131
|
+
# Trait extraction util
|
132
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
133
|
+
|
134
|
+
|
135
|
+
def world_traits(env: crafter.Env, radius: int = RADIUS) -> Dict[str, int]:
    """Count trees, cows and hostile mobs within ``radius`` (Manhattan) of the player.

    In Crafter, trees are terrain *material* tiles, not entries in the object
    list. They are therefore counted from the world's material map. Cows and
    zombies/skeletons are counted from the object list.

    Returns:
        ``{"trees": int, "cows": int, "hostiles": int}``
    """
    player = env._player  # type: ignore[attr-defined]
    world = env._world  # type: ignore[attr-defined]
    pos = np.array(player.pos)
    counts = {"trees": 0, "cows": 0, "hostiles": 0}

    # Trees: tiles whose material id is "tree". If the world does not expose a
    # material map (unexpected crafter internals), the tree count stays 0.
    mat_map = getattr(world, "_mat_map", None)
    tree_id = getattr(world, "_mat_ids", {}).get("tree")
    if mat_map is not None and tree_id is not None:
        xs, ys = np.nonzero(mat_map == tree_id)
        dists = np.abs(xs - pos[0]) + np.abs(ys - pos[1])
        counts["trees"] = int(np.count_nonzero(dists <= radius))

    for obj in world._objects:
        if obj is None or obj is player:
            continue
        if np.abs(obj.pos - pos).sum() > radius:
            continue
        if isinstance(obj, objects.Cow):
            counts["cows"] += 1
        elif isinstance(obj, (objects.Zombie, objects.Skeleton)):
            counts["hostiles"] += 1
    return counts
|
151
|
+
|
152
|
+
|
153
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
154
|
+
# Main generator
|
155
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
156
|
+
|
157
|
+
|
158
|
+
def _assign_difficulty(traits: Dict[str, int]) -> Optional[str]:
    """Return the first TRAIT_BOUNDS tier whose bounds ``traits`` satisfy, else None."""
    for tier, bounds in TRAIT_BOUNDS.items():
        if (
            traits["trees"] >= bounds["min_trees"]
            and traits["hostiles"] <= bounds["max_hostiles"]
        ):
            return tier
    return None


async def create_crafter_taskset(
    num_instances: int = NUM_INSTANCES,
    split_seed: Optional[int] = 0,
) -> TaskInstanceSet:
    """Generate ``num_instances`` Crafter task instances by scanning seeds.

    Seeds are tried in increasing order starting at SEED_START. Each world is
    reset and measured with ``world_traits``, and gets the first TRAIT_BOUNDS
    tier it satisfies. Seeds that satisfy no tier are skipped.

    The instances are then split 80/10/10 into train/val/test after a shuffle
    driven by ``random.Random(split_seed)``. The same ``split_seed`` therefore
    puts the same seeds in each split on every run, and the global ``random``
    state is left untouched. Pass ``split_seed=None`` for a non-deterministic
    shuffle.
    """
    instances: List[CrafterTaskInstance] = []
    seed = SEED_START
    while len(instances) < num_instances:
        env = crafter.Env(area=AREA, length=LEN, seed=seed)
        env.reset()
        traits = world_traits(env)
        difficulty = _assign_difficulty(traits)
        if difficulty is not None:
            instances.append(
                CrafterTaskInstance(
                    id=uuid4(),
                    impetus=Impetus(
                        instructions=f"Survive and unlock achievements. Difficulty={difficulty}."
                    ),
                    intent=Intent(
                        rubric={"goal": "Unlock as many achievements as possible."},
                        gold_trajectories=None,
                        gold_state_diff={},
                    ),
                    metadata=CrafterTaskInstanceMetadata(
                        difficulty=difficulty,
                        seed=seed,
                        num_trees_radius=traits["trees"],
                        num_cows_radius=traits["cows"],
                        num_hostiles_radius=traits["hostiles"],
                    ),
                    is_reproducible=True,
                    initial_engine_snapshot=None,  # will be filled lazily when env starts
                )
            )
        seed += 1

    # 80/10/10 split with a dedicated RNG (train = everything not in val/test).
    random.Random(split_seed).shuffle(instances)
    n = len(instances)
    val_ids = {inst.id for inst in instances[int(0.8 * n) : int(0.9 * n)]}
    test_ids = {inst.id for inst in instances[int(0.9 * n) :]}
    split = SplitInfo(val_instance_ids=val_ids, test_instance_ids=test_ids, _is_split_defined=True)

    return TaskInstanceSet(
        name="Crafter Procedural TaskSet",
        description="Crafter seeds filtered by local world traits around spawn.",
        instances=instances,
        split_info=split,
    )
|
215
|
+
|
216
|
+
|
217
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
218
|
+
# CLI example
|
219
|
+
# ──────────────────────────────────────────────────────────────────────────────
|
220
|
+
|
221
|
+
if __name__ == "__main__":
    import json
    import pathlib

    async def _main():
        """Generate 30 instances and write them as JSON to dataset/crafter_instances.json."""
        taskset = await create_crafter_taskset(30)
        records = await asyncio.gather(*[instance.serialize() for instance in taskset.instances])
        destination = pathlib.Path("dataset/crafter_instances.json")
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_text(json.dumps(records, indent=2))
        print(f"Saved {len(records)} instances → {destination}")

    asyncio.run(_main())
|