synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +104 -6
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,575 @@
|
|
1
|
+
"""CrafterEngine — Stateful, reproducible wrapper around danijar/crafter.Env.
|
2
|
+
This file follows the same structure as the SokobanEngine (synth_ai/environments/examples/sokoban/engine.py).
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
# Import logging configuration first to suppress JAX debug messages
|
8
|
+
from .config_logging import safe_compare
|
9
|
+
|
10
|
+
# Import patches
|
11
|
+
from . import engine_deterministic_patch # Ensures deterministic behavior
|
12
|
+
from . import engine_serialization_patch_v3 as engine_serialization_patch # Adds save/load methods
|
13
|
+
from . import world_config_patch_simple as world_config_patch # Adds configurable world generation
|
14
|
+
|
15
|
+
import logging
|
16
|
+
import time
|
17
|
+
from dataclasses import dataclass
|
18
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
19
|
+
|
20
|
+
import numpy as np
|
21
|
+
import crafter # type: ignore
|
22
|
+
import copy
|
23
|
+
import dataclasses
|
24
|
+
|
25
|
+
from synth_ai.environments.environment.shared_engine import (
|
26
|
+
GetObservationCallable,
|
27
|
+
InternalObservation,
|
28
|
+
)
|
29
|
+
from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
|
30
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
31
|
+
from synth_ai.environments.reproducibility.core import IReproducibleEngine
|
32
|
+
from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent # Added
|
33
|
+
|
34
|
+
# Local helper imports (must exist relative to this file)
|
35
|
+
from .engine_helpers.action_map import CRAFTER_ACTION_MAP # action‑name → int
|
36
|
+
from .engine_helpers.serialization import (
|
37
|
+
serialize_world_object,
|
38
|
+
)
|
39
|
+
|
40
|
+
# Configure the root logger at INFO level as soon as this module is imported.
# NOTE(review): a library module configuring root logging affects every
# importer; consider moving this call to application entry points.
logging.basicConfig(level=logging.INFO)
# Module-level logger used throughout this engine.
logger = logging.getLogger(name=__name__)
|
42
|
+
|
43
|
+
# -----------------------------------------------------------------------------
|
44
|
+
# Dataclasses for snapshot & (public, private) runtime state
|
45
|
+
# -----------------------------------------------------------------------------
|
46
|
+
|
47
|
+
|
48
|
+
@dataclass
class CrafterEngineSnapshot(StatefulEngineSnapshot):
    """Serializable snapshot used to restore a ``CrafterEngine``.

    ``env_raw_state`` is whatever ``crafter.Env.save()`` returned. The two
    ``previous_*_state_snapshot`` entries are optional dict dumps of the
    engine's last public/private states, kept so reward bookkeeping can
    resume after a restore.
    """

    env_raw_state: Any  # opaque payload produced by crafter.Env.save()
    total_reward_snapshot: float  # cumulative episode reward at snapshot time
    crafter_seed: Optional[int] = dataclasses.field(default=None)
    previous_public_state_snapshot: Optional[Dict] = dataclasses.field(default=None)
    previous_private_state_snapshot: Optional[Dict] = dataclasses.field(default=None)
    # The RewardStack is configured once at engine construction, so it is not
    # captured here; any internal state it accumulated would need saving too.
|
58
|
+
|
59
|
+
|
60
|
+
@dataclass
class CrafterPublicState:
    """Agent-visible view of a Crafter episode at one point in time."""

    inventory: Dict[str, int]
    achievements_status: Dict[str, bool]
    player_position: Tuple[int, int]
    player_direction: Union[int, Tuple[int, int]]
    semantic_map: Optional[np.ndarray]
    world_material_map: np.ndarray
    observation_image: np.ndarray
    num_steps_taken: int
    max_steps_episode: int
    error_info: Optional[str] = None

    def diff(self, prev_state: "CrafterPublicState") -> Dict[str, Any]:
        """Return the fields whose value differs between *prev_state* and ``self``.

        Fields where either side is a NumPy array map to ``True`` when they
        changed (this covers ``semantic_map`` moving between ``None`` and an
        array). Every other changed field maps to ``(old_value, new_value)``.
        Unchanged fields are omitted.
        """
        changes: Dict[str, Any] = {}
        for dc_field in dataclasses.fields(self):
            name = dc_field.name
            new_v, old_v = getattr(self, name), getattr(prev_state, name)
            if isinstance(new_v, np.ndarray) or isinstance(old_v, np.ndarray):
                # `!=` on arrays is elementwise and cannot be used as a bool,
                # so compare with array_equal (False for None vs array).
                if not np.array_equal(new_v, old_v):
                    changes[name] = True
            elif new_v != old_v:
                changes[name] = (old_v, new_v)
        return changes
|
83
|
+
|
84
|
+
|
85
|
+
@dataclass
class CrafterPrivateState:
    """Engine-side state not shown to the agent: rewards, termination, RNG."""

    reward_last_step: float
    total_reward_episode: float
    achievements_current_values: Dict[str, int]
    terminated: bool
    truncated: bool
    player_internal_stats: Dict[str, Any]
    # Whatever the world RNG's get_state() returned; for NumPy's legacy
    # RandomState this is a tuple that contains an ndarray.
    world_rng_state_snapshot: Any

    @staticmethod
    def _values_differ(a: Any, b: Any) -> bool:
        """Return True if *a* and *b* differ, comparing nested arrays safely.

        Plain ``!=`` on a tuple holding ndarrays (e.g. an RNG state) raises
        "truth value of an array is ambiguous". Arrays are therefore compared
        with ``np.array_equal``, and tuples/lists/dicts are walked element-wise.
        """
        if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
            return not np.array_equal(a, b)
        if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)):
            # A tuple never equals a list, matching built-in `!=` semantics.
            if isinstance(a, tuple) != isinstance(b, tuple) or len(a) != len(b):
                return True
            return any(
                CrafterPrivateState._values_differ(x, y) for x, y in zip(a, b)
            )
        if isinstance(a, dict) and isinstance(b, dict):
            if a.keys() != b.keys():
                return True
            return any(CrafterPrivateState._values_differ(a[k], b[k]) for k in a)
        return bool(a != b)

    def diff(self, prev_state: "CrafterPrivateState") -> Dict[str, Any]:
        """Map each field that changed since *prev_state* to ``(old_value, new_value)``."""
        changes: Dict[str, Any] = {}
        for dc_field in dataclasses.fields(self):
            name = dc_field.name
            new_v, old_v = getattr(self, name), getattr(prev_state, name)
            if self._values_differ(new_v, old_v):
                changes[name] = (old_v, new_v)
        return changes
|
102
|
+
|
103
|
+
|
104
|
+
# -----------------------------------------------------------------------------
|
105
|
+
# Observation helpers
|
106
|
+
# -----------------------------------------------------------------------------
|
107
|
+
|
108
|
+
|
109
|
+
class CrafterObservationCallable(GetObservationCallable):
    """Default observation builder: one flat dict mixing public and private fields."""

    def __init__(self) -> None:
        """Stateless; there is nothing to configure."""

    async def get_observation(
        self, pub: CrafterPublicState, priv: CrafterPrivateState
    ) -> InternalObservation:  # type: ignore[override]
        """Return a dict of agent-facing values.

        Keys: ``inventory``, ``achievements``, ``player_pos`` and ``steps``
        (from *pub*), then ``reward_last``, ``total_reward``, ``terminated``
        and ``truncated`` (from *priv*).
        """
        public_fields = {
            "inventory": pub.inventory,
            "achievements": pub.achievements_status,
            "player_pos": pub.player_position,
            "steps": pub.num_steps_taken,
        }
        private_fields = {
            "reward_last": priv.reward_last_step,
            "total_reward": priv.total_reward_episode,
            "terminated": priv.terminated,
            "truncated": priv.truncated,
        }
        observation: Dict[str, Any] = {**public_fields, **private_fields}
        return observation  # type: ignore[return-value]
|
127
|
+
|
128
|
+
|
129
|
+
# -----------------------------------------------------------------------------
|
130
|
+
# CrafterEngine implementation
|
131
|
+
# -----------------------------------------------------------------------------
|
132
|
+
|
133
|
+
|
134
|
+
class CrafterEngine(StatefulEngine, IReproducibleEngine):
    """StatefulEngine wrapper around `crafter.Env` supporting full snapshotting.

    The engine owns one ``crafter.Env``, accumulates the episode reward and
    converts raw Crafter state into ``(CrafterPrivateState, CrafterPublicState)``
    pairs for the reset/step API.
    """

    task_instance: TaskInstance
    env: crafter.Env

    # ────────────────────────────────────────────────────────────────────────
    # Construction helpers
    # ────────────────────────────────────────────────────────────────────────

    def __init__(self, task_instance: TaskInstance):
        """Build the underlying ``crafter.Env`` from *task_instance*.

        ``task_instance.config`` (a dict; may be missing or empty) supplies
        ``area`` (default ``(64, 64)``), ``length`` (default ``10000``),
        ``seed``, ``world_config`` and ``world_config_path``. When
        ``task_instance.metadata`` has ``seed``, ``world_config`` or
        ``world_config_path`` attributes, those take precedence.
        """
        self.task_instance = task_instance
        self._total_reward: float = 0.0
        self._current_action_for_reward: Optional[int] = None
        # Public/private states as they stand before the next step; kept for
        # reward bookkeeping and snapshotting.
        self._previous_public_state_for_reward: Optional[CrafterPublicState] = None
        self._previous_private_state_for_reward: Optional[CrafterPrivateState] = None

        # Names of achievements reported with a truthy value in step ``info``.
        self.achievements_unlocked: set = set()

        # Last raw step results; refreshed by reset, step and deserialization.
        self.obs: Any = None
        self.done: bool = False
        self.info: Dict[str, Any] = {}
        self.last_reward: float = 0.0

        cfg = getattr(task_instance, "config", {}) or {}
        area: Tuple[int, int] = tuple(cfg.get("area", (64, 64)))  # type: ignore[arg-type]
        length: int = int(cfg.get("length", 10000))
        metadata = getattr(task_instance, "metadata", None)

        # Seed: metadata wins over config.
        seed: Optional[int] = cfg.get("seed")
        if hasattr(metadata, "seed"):
            seed = metadata.seed

        # World configuration: metadata wins over config; default is "normal".
        world_config = "normal"
        if hasattr(metadata, "world_config"):
            world_config = metadata.world_config
            logger.info(f"CrafterEngine: Using world_config from metadata: {world_config}")
        elif cfg.get("world_config"):
            world_config = cfg.get("world_config")
            logger.info(f"CrafterEngine: Using world_config from cfg: {world_config}")

        world_config_path = None
        if hasattr(metadata, "world_config_path"):
            world_config_path = metadata.world_config_path
        elif cfg.get("world_config_path"):
            world_config_path = cfg.get("world_config_path")

        # Remembered so that reseeding in _reset_engine rebuilds the same world type.
        self._world_config = world_config
        self._world_config_path = world_config_path

        logger.info(f"CrafterEngine: Creating Env with world_config={world_config}, seed={seed}")
        self.env = self._make_env(area, length, seed)

        # NOTE(review): reward_stack is constructed here but not consulted by
        # _step_engine; the step reward comes from crafter.Env.step().
        self.reward_stack = RewardStack(
            components=[
                CrafterAchievementComponent(),
                CrafterPlayerStatComponent(),
                CrafterStepPenaltyComponent(penalty=-0.001),
            ]
        )

    def _make_env(self, area: Tuple[int, int], length: int, seed: Optional[int]) -> crafter.Env:
        """Create a ``crafter.Env`` using this engine's world configuration.

        The seed is also stored on ``env._seed``, which ``_serialize_engine``
        reads back.
        """
        env = crafter.Env(
            area=area,
            length=length,
            seed=seed,
            world_config=getattr(self, "_world_config", "normal"),
            world_config_path=getattr(self, "_world_config_path", None),
        )
        # Store the original seed for reproducibility.
        env._seed = seed
        return env

    # ────────────────────────────────────────────────────────────────────────
    # Utility: action validation / mapping
    # ────────────────────────────────────────────────────────────────────────

    def _validate_action_engine(self, action: Union[int, str]) -> int:  # type: ignore[override]
        """Map *action* (a name or an index) to a valid Crafter action index.

        Names are looked up in ``CRAFTER_ACTION_MAP``; unknown names map to 0.
        Integer indices, including NumPy integers, are clipped into
        ``[0, len(crafter.constants.actions) - 1]``. Any other value maps to 0.
        """
        if isinstance(action, str):
            action = CRAFTER_ACTION_MAP.get(action, 0)
        # Accept NumPy integer scalars too: policies often emit them, and
        # they are not `int` instances.
        if not isinstance(action, (int, np.integer)):
            return 0
        return int(np.clip(action, 0, len(crafter.constants.actions) - 1))  # type: ignore

    # ────────────────────────────────────────────────────────────────────────
    # Core StatefulEngine API
    # ────────────────────────────────────────────────────────────────────────

    async def _reset_engine(
        self, *, seed: Optional[int] = None
    ) -> Tuple[CrafterPrivateState, CrafterPublicState]:
        """Start a new episode and return its initial ``(private, public)`` states.

        If *seed* is given, the Env is first rebuilt with that seed, keeping
        the same area, length and world configuration. The reward total,
        achievement tracking and reward-bookkeeping states are cleared.
        """
        if seed is not None:
            # Re-instantiate env with the new seed to match crafter's internal
            # reseeding convention.
            self.env = self._make_env(self.env._area, self.env._length, seed)
        obs_img = self.env.reset()

        self._total_reward = 0.0
        self.achievements_unlocked = set()
        self.obs = obs_img
        self.done = False
        self.info = {}
        self.last_reward = 0.0

        pub = self._build_public_state(obs_img)
        priv = self._build_private_state(reward=0.0, terminated=False, truncated=False)
        self._previous_public_state_for_reward = pub
        self._previous_private_state_for_reward = priv
        return priv, pub

    async def _step_engine(self, action: int) -> Tuple[CrafterPrivateState, CrafterPublicState]:
        """Apply *action* and return the resulting ``(private, public)`` states.

        Both returned states describe the world AFTER the action. Crafter's
        single ``done`` flag is split into ``terminated`` (player health <= 0)
        and ``truncated`` (any other end of episode). Exceptions are not
        propagated: they are logged and turned into an error state pair
        (``reward_last_step=-1.0``, ``terminated=True``, ``error_info`` set).
        """
        step_start_time = time.perf_counter()
        try:
            n_actions = self.env.action_space.n
            if action < 0 or action >= n_actions:
                raise ValueError(f"Invalid action {action}, must be in range [0, {n_actions})")

            self._ensure_player()

            crafter_step_start = time.perf_counter()
            obs, reward, done, info = self.env.step(action)
            crafter_step_time = time.perf_counter() - crafter_step_start
            logger.debug(f"Crafter env.step() took {crafter_step_time:.3f}s")

            self.obs = obs
            self.done = done
            self.info = info
            self.last_reward = reward
            # The step counter itself is tracked by crafter in self.env._step.

            if "achievements" in info:
                for achievement, status in info["achievements"].items():
                    if status:
                        self.achievements_unlocked.add(achievement)

            step_reward = reward + self._drain_extra_rewards()
            self._total_reward += step_reward

            terminated, truncated = self._termination_flags(done)
            # Pass the per-step reward; _build_private_state reads the running
            # total from self._total_reward.
            final_priv_state = self._build_private_state(step_reward, terminated, truncated)
            # Build from the post-step observation so the caller sees the
            # effect of the action rather than the pre-step world.
            final_pub_state = self._build_public_state(obs)

            self._previous_public_state_for_reward = final_pub_state
            self._previous_private_state_for_reward = final_priv_state

            total_step_time = time.perf_counter() - step_start_time
            logger.debug(
                f"CrafterEngine _step_engine took {total_step_time:.3f}s (crafter.step: {crafter_step_time:.3f}s)"
            )
            return final_priv_state, final_pub_state

        except Exception as e:
            logger.exception(f"Step engine error: {e}")
            error_pub_state = self._get_public_state_from_env()
            error_pub_state.error_info = f"Step engine error: {e}"
            error_priv_state = self._get_private_state_from_env(
                reward=-1.0, terminated=True, truncated=False
            )
            return error_priv_state, error_pub_state

    def _ensure_player(self) -> None:
        """Re-attach ``env._player`` from the world's objects if it is missing.

        Raises:
            RuntimeError: if no world object's class is named ``Player``.
        """
        if self.env._player is not None:
            return
        for obj in self.env._world._objects:
            if obj is not None and obj.__class__.__name__ == "Player":
                self.env._player = obj
                return
        raise RuntimeError("Player object not found in world")

    def _termination_flags(self, done: bool) -> Tuple[bool, bool]:
        """Split Crafter's ``done`` into ``(terminated, truncated)``.

        ``terminated`` is True only when the player's health is <= 0. Any
        other ``done`` (step limit or an unknown reason) counts as truncation.
        """
        if not done:
            return False, False
        if self.env._player.health <= 0:  # type: ignore[attr-defined]
            return True, False
        return False, True

    def _drain_extra_rewards(self) -> float:
        """Sum and clear any values queued on ``self._reward_stack``.

        Nothing visible in this module populates ``_reward_stack``, which is
        distinct from ``self.reward_stack``. Unusable contents are logged and
        contribute 0.
        """
        pending = getattr(self, "_reward_stack", None)
        if not pending:
            return 0.0
        try:
            extra = sum(pending)
            pending.clear()
        except Exception:
            logger.warning("Ignoring unusable _reward_stack contents", exc_info=True)
            return 0.0
        return extra

    # ------------------------------------------------------------------
    # Rendering (simple text summary)
    # ------------------------------------------------------------------

    async def _render(
        self,
        private_state: CrafterPrivateState,
        public_state: CrafterPublicState,
        get_observation: Optional[GetObservationCallable] = None,
    ) -> str:  # type: ignore[override]
        """Render the states as text.

        Uses *get_observation* (default ``CrafterObservationCallable``). A
        string observation is returned as-is. A dict observation becomes a
        three-line summary (steps/rewards, non-zero inventory, achieved
        achievements). Anything else is passed through ``str()``.
        """
        obs_cb = get_observation or CrafterObservationCallable()
        obs = await obs_cb.get_observation(public_state, private_state)
        if isinstance(obs, str):
            return obs
        if isinstance(obs, dict):
            header = f"steps: {public_state.num_steps_taken}/{public_state.max_steps_episode} | "
            header += f"last_r: {private_state.reward_last_step:.2f} | total_r: {private_state.total_reward_episode:.2f}"
            inv = ", ".join(f"{k}:{v}" for k, v in public_state.inventory.items() if v)
            ach = ", ".join(k for k, v in public_state.achievements_status.items() if v)
            return f"{header}\ninv: {inv}\nach: {ach}"
        return str(obs)

    # ------------------------------------------------------------------
    # Snapshotting for exact reproducibility
    # ------------------------------------------------------------------

    async def _serialize_engine(self) -> CrafterEngineSnapshot:
        """Capture the env state, total reward, seed and reward-bookkeeping states."""
        prev_pub = self._previous_public_state_for_reward
        prev_priv = self._previous_private_state_for_reward
        return CrafterEngineSnapshot(
            env_raw_state=self.env.save(),
            total_reward_snapshot=self._total_reward,
            crafter_seed=self.env._seed,
            previous_public_state_snapshot=(
                dataclasses.asdict(prev_pub) if prev_pub is not None else None
            ),
            previous_private_state_snapshot=(
                dataclasses.asdict(prev_priv) if prev_priv is not None else None
            ),
        )

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: CrafterEngineSnapshot, task_instance: TaskInstance
    ) -> "CrafterEngine":
        """Recreate an engine for *task_instance* and load *snapshot* into it."""
        engine = cls(task_instance)
        # reset() first so the Env's internal structures exist, then overwrite
        # them with the saved state.
        engine.env.reset()
        engine.env.load(snapshot.env_raw_state)
        engine._total_reward = snapshot.total_reward_snapshot
        engine.env._seed = snapshot.crafter_seed

        # Initialize the attributes that step() maintains.
        engine.obs = engine.env.render()
        engine.done = False
        engine.info = {}
        engine.last_reward = 0.0

        player = engine.env._player
        # Rebuild achievement tracking from the restored player instead of
        # starting empty; safe_compare mirrors _build_public_state's check.
        engine.achievements_unlocked = {
            name for name, value in player.achievements.items() if safe_compare(0, value, "<")
        }

        # Re-establish previous states so reward bookkeeping continues after load.
        engine._previous_public_state_for_reward = engine._build_public_state(engine.env.render())
        # Safe comparisons to avoid string vs int errors.
        health_dead = safe_compare(0, player.health, ">=")
        step_exceeded = safe_compare(engine.env._length, engine.env._step, "<=")
        engine._previous_private_state_for_reward = engine._build_private_state(
            0.0, health_dead, step_exceeded
        )
        return engine

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _placeholder_public_state(
        obs_img: Optional[np.ndarray], error_info: str
    ) -> CrafterPublicState:
        """Return a minimal public state for when the real one cannot be built."""
        return CrafterPublicState(
            inventory={},
            achievements_status={},
            player_position=(0, 0),
            player_direction=0,
            semantic_map=None,
            world_material_map=np.zeros((1, 1), dtype=np.uint8),
            observation_image=(
                obs_img if obs_img is not None else np.zeros((64, 64, 3), dtype=np.uint8)
            ),
            num_steps_taken=0,
            max_steps_episode=10000,
            error_info=error_info,
        )

    def _build_public_state(
        self, obs_img: np.ndarray, info: Optional[Dict[str, Any]] = None
    ) -> CrafterPublicState:
        """Build the public state from the live env.

        If *info* is given, inventory, achievements and semantic map come from
        its ``inventory``, ``achievements`` and ``semantic`` keys. Otherwise
        they are read from the player / env directly. Position, direction,
        material map and step counts always come from the env. On any error a
        placeholder state is returned with ``error_info`` set.
        """
        try:
            player = self.env._player  # type: ignore[attr-defined]
            if info is None:
                inventory = player.inventory.copy()
                achievement_values = player.achievements
                semantic = getattr(self.env, "_sem_view", lambda: None)()
            else:
                inventory = info.get("inventory", {})
                achievement_values = info.get("achievements", {})
                semantic = info.get("semantic")
            # Safe achievement status check (values may not be plain ints).
            achievements_status = {
                k: safe_compare(0, v, "<") for k, v in achievement_values.items()
            }
            return CrafterPublicState(
                inventory=inventory,
                achievements_status=achievements_status,
                player_position=tuple(player.pos),  # type: ignore[attr-defined]
                player_direction=player.facing,  # type: ignore[attr-defined]
                semantic_map=semantic,
                world_material_map=self.env._world._mat_map.copy(),  # type: ignore[attr-defined]
                observation_image=obs_img,
                num_steps_taken=self.env._step,  # type: ignore[attr-defined]
                max_steps_episode=self.env._length,  # type: ignore[attr-defined]
                error_info=info.get("error_info") if info else None,
            )
        except Exception as e:
            logger.error(f"Error building public state: {e}")
            return self._placeholder_public_state(obs_img, f"State building error: {e}")

    def _build_private_state(
        self, reward: float, terminated: bool, truncated: bool
    ) -> CrafterPrivateState:
        """Build the private state.

        *reward* becomes ``reward_last_step``; the episode total is read from
        ``self._total_reward``.
        """
        player = self.env._player  # type: ignore[attr-defined]
        stats = {
            "health": player.health,
            "food": player.inventory.get("food"),
            "drink": player.inventory.get("drink"),
            "energy": player.inventory.get("energy"),
            "_hunger": getattr(player, "_hunger", 0),
            "_thirst": getattr(player, "_thirst", 0),
        }
        return CrafterPrivateState(
            reward_last_step=reward,
            total_reward_episode=self._total_reward,
            achievements_current_values=player.achievements.copy(),
            terminated=terminated,
            truncated=truncated,
            player_internal_stats=stats,
            world_rng_state_snapshot=self.env._world.random.get_state(),  # type: ignore[attr-defined]
        )

    def _get_public_state_from_env(self) -> CrafterPublicState:
        """Render the env and build its public state; return a placeholder on failure."""
        try:
            obs_img = self.env.render()
            return self._build_public_state(obs_img)
        except Exception as e:
            logger.error(f"Error getting public state from env: {e}")
            return self._placeholder_public_state(None, f"State extraction error: {e}")

    def _get_private_state_from_env(
        self, reward: float, terminated: bool, truncated: bool
    ) -> CrafterPrivateState:
        """Build the private state; on failure return a mostly-empty default."""
        try:
            return self._build_private_state(reward, terminated, truncated)
        except Exception as e:
            logger.error(f"Error getting private state from env: {e}")
            return CrafterPrivateState(
                reward_last_step=reward,
                total_reward_episode=0.0,
                achievements_current_values={},
                terminated=terminated,
                truncated=truncated,
                player_internal_stats={},
                world_rng_state_snapshot=None,
            )
|
544
|
+
|
545
|
+
|
546
|
+
# --- Reward Components ---
|
547
|
+
class CrafterAchievementComponent(RewardComponent):
    """Reward achievements that became unlocked since the previous step."""

    async def score(self, state: CrafterPublicState, action: Dict[str, Any]) -> float:
        """Return 0.1 for every achievement newly unlocked in ``state``.

        Args:
            state: Public state whose ``achievements_status`` is iterated as
                (achievement, status) pairs; a truthy status counts as unlocked.
            action: Dict that may carry ``"previous_public_state_achievements"``,
                the previous status mapping; treated as empty when absent.

        Returns:
            0.1 times the number of achievements whose status is truthy now
            but was falsy or missing in the previous mapping.
        """
        earlier = action.get("previous_public_state_achievements", {})
        newly_unlocked = [
            name
            for name, unlocked in state.achievements_status.items()
            if unlocked and not earlier.get(name)
        ]
        return float(len(newly_unlocked)) * 0.1
|
557
|
+
|
558
|
+
|
559
|
+
class CrafterPlayerStatComponent(RewardComponent):
    """Penalise health loss between consecutive steps."""

    async def score(self, state: CrafterPrivateState, action: Dict[str, Any]) -> float:
        """Return -0.05 if health dropped since the previous step, else 0.0.

        Args:
            state: Private state; ``player_internal_stats["health"]`` is read.
            action: Dict that may carry ``"previous_private_state_stats"``, the
                previous step's stats. If it (or its ``health``) is missing,
                no change is assumed.

        Returns:
            -0.05 on a health decrease, otherwise 0.0.
        """
        current_health = state.player_internal_stats.get("health")
        if current_health is None:
            # Unknown health (e.g. an empty stats dict from a failed state
            # extraction) must not be read as 0, which would look like a
            # health loss and trigger a spurious penalty.
            return 0.0
        # ``or {}`` also tolerates an explicit None for the previous stats.
        prev_stats = action.get("previous_private_state_stats") or {}
        prev_health = prev_stats.get("health", current_health)
        if current_health < prev_health:
            return -0.05  # Lost health penalty
        return 0.0
|
566
|
+
|
567
|
+
|
568
|
+
class CrafterStepPenaltyComponent(RewardComponent):
    """Constant reward term applied on every step (negative by default)."""

    def __init__(self, penalty: float = -0.001):
        """Store the constant that :meth:`score` returns.

        Args:
            penalty: Value returned by every call to ``score``.
        """
        super().__init__()
        # The component's weight is fixed at 1.0.
        self.penalty, self.weight = penalty, 1.0

    async def score(self, state: Any, action: Any) -> float:
        """Ignore ``state`` and ``action`` and return ``self.penalty``."""
        return self.penalty
|
@@ -0,0 +1,63 @@
|
|
1
|
+
"""
|
2
|
+
Apply once (import this module anywhere before CrafterEngine is used).
It monkey-patches crafter so that simulation order is deterministic:
Env._balance_object receives every per-chunk object list sorted by
(x, y, class-name) before any random choice is made, and World.chunks and
World.objects iterate in a stable, sorted order – removing the hash-based
set-iteration nondeterminism that caused the drift.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import collections
|
9
|
+
import crafter
|
10
|
+
|
11
|
+
print("[PATCH] Attempting to apply Crafter deterministic patch...")

# -----------------------------------------------------------------------------
# 1. Make per-chunk object order stable
# -----------------------------------------------------------------------------
if hasattr(crafter.Env, "_orig_balance_object"):
    # The original is already saved, so the wrapper is in place: don't wrap twice.
    print("[PATCH] crafter.Env._balance_object already patched or _orig_balance_object exists.")
else:
    print("[PATCH] Patching crafter.Env._balance_object...")
    # Keep the unpatched implementation reachable for the wrapper below.
    crafter.Env._orig_balance_object = crafter.Env._balance_object

    def _balance_object_det(self, chunk, objs, *args, **kwargs):
        """Call the original ``_balance_object`` with ``objs`` in a fixed order.

        ``objs`` is sorted by (x, y, class-name) before delegating, so any
        random choice the original makes over the list is reproducible. The
        remaining arguments (cls, material, span/despawn distances and
        probabilities, ctor, target_fn in the original signature) are
        forwarded unchanged via ``*args``/``**kwargs``.
        """
        def _order_key(obj):
            return obj.pos[0], obj.pos[1], obj.__class__.__name__

        stable_objs = sorted(objs, key=_order_key)
        original_impl = crafter.Env._orig_balance_object
        return original_impl(self, chunk, stable_objs, *args, **kwargs)

    crafter.Env._balance_object = _balance_object_det
    print("[PATCH] crafter.Env._balance_object patched.")
|
30
|
+
|
31
|
+
# -----------------------------------------------------------------------------
|
32
|
+
# 2. Make *chunk* iteration order stable
|
33
|
+
# -----------------------------------------------------------------------------
|
34
|
+
if not hasattr(crafter.engine.World, "_orig_chunks_prop"):
    # Keep a handle on the original ``chunks`` property object.
    crafter.engine.World._orig_chunks_prop = crafter.engine.World.chunks

    def _chunks_sorted(self):
        """Return ``self._chunks`` as an OrderedDict iterated in sorted key order.

        The values are the same objects stored in ``self._chunks`` (not
        copies); only the mapping around them is new.
        """
        # Keys are unique, so ordering the items by key alone matches a plain
        # sort of the (key, value) pairs.
        by_key = sorted(self._chunks.items(), key=lambda item: item[0])
        return collections.OrderedDict(by_key)

    crafter.engine.World.chunks = property(fget=_chunks_sorted)
|
42
|
+
|
43
|
+
# -----------------------------------------------------------------------------
|
44
|
+
# 3. NEW: keep per-frame object update order deterministic
|
45
|
+
# -----------------------------------------------------------------------------
|
46
|
+
if not hasattr(crafter.engine.World, "_orig_objects_prop"):
    crafter.engine.World._orig_objects_prop = crafter.engine.World.objects  # save original

    @property
    def _objects_sorted(self):
        """Live world objects in a deterministic per-frame update order.

        Falsy entries in ``self._objects`` (removed objects) are skipped. The
        rest are ordered by x, y and class name, and finally by their slot
        index in ``self._objects`` as a tie-breaker.
        """
        # Pair each object with its slot index. The index is always defined,
        # unlike an ``_id`` attribute, which a getattr default would silently
        # turn into a constant 0 that never breaks a tie.
        indexed = [(slot, o) for slot, o in enumerate(self._objects) if o]
        indexed.sort(
            key=lambda item: (
                item[1].pos[0],
                item[1].pos[1],
                item[1].__class__.__name__,
                item[0],
            ),
        )
        return [o for _, o in indexed]

    crafter.engine.World.objects = _objects_sorted
|