synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +579 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/environments/examples/wordle/__init__.py +29 -0
  100. synth_ai/environments/examples/wordle/engine.py +391 -0
  101. synth_ai/environments/examples/wordle/environment.py +154 -0
  102. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  103. synth_ai/environments/examples/wordle/taskset.py +222 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/environments/service/core_routes.py +38 -0
  106. synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
  107. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
  108. synth_ai/learning/prompts/mipro.py +273 -1
  109. synth_ai/learning/prompts/random_search.py +247 -0
  110. synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
  111. synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
  112. synth_ai/lm/injection.py +81 -0
  113. synth_ai/lm/overrides.py +204 -0
  114. synth_ai/lm/provider_support/anthropic.py +39 -12
  115. synth_ai/lm/provider_support/openai.py +31 -4
  116. synth_ai/lm/vendors/core/anthropic_api.py +16 -0
  117. synth_ai/lm/vendors/openai_standard.py +35 -5
  118. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
  119. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
  120. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
  121. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
  122. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
  123. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,579 @@
1
+ """CrafterEngine — Stateful, reproducible wrapper around danijar/crafter.Env.
2
+ This file follows the same structure as the SokobanEngine shown earlier.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ # Import logging configuration first to suppress JAX debug messages
8
+ from .config_logging import safe_compare
9
+
10
+ # Import patches
11
+ from . import engine_deterministic_patch # Ensures deterministic behavior
12
+ from . import engine_serialization_patch_v3 as engine_serialization_patch # Adds save/load methods
13
+ from . import world_config_patch_simple as world_config_patch # Adds configurable world generation
14
+
15
+ import logging
16
+ import time
17
+ from dataclasses import dataclass
18
+ from typing import Any, Dict, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import crafter # type: ignore
22
+ import copy
23
+ import dataclasses
24
+
25
+ from synth_ai.environments.environment.shared_engine import (
26
+ GetObservationCallable,
27
+ InternalObservation,
28
+ )
29
+ from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
30
+ from synth_ai.environments.tasks.core import TaskInstance
31
+ from synth_ai.environments.reproducibility.core import IReproducibleEngine
32
+ from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent # Added
33
+
34
+ # Local helper imports (must exist relative to this file)
35
+ from .engine_helpers.action_map import CRAFTER_ACTION_MAP # action‑name → int
36
+ from .engine_helpers.serialization import (
37
+ serialize_world_object,
38
+ )
39
+
40
# Module-level logger used by the engine and its helpers.
logger = logging.getLogger(__name__)
# NOTE(review): this configures the *root* logger as a side effect of importing
# the module (per the stdlib logging docs it does nothing if the root logger
# already has handlers). Libraries usually leave this to the application —
# confirm it is intentional.
logging.basicConfig(level=logging.INFO)
42
+
43
+ # -----------------------------------------------------------------------------
44
+ # Dataclasses for snapshot & (public, private) runtime state
45
+ # -----------------------------------------------------------------------------
46
+
47
+
48
@dataclass
class CrafterEngineSnapshot(StatefulEngineSnapshot):
    """Everything needed to rebuild a ``CrafterEngine`` at an earlier point."""

    # Opaque payload returned by ``crafter.Env.save()``.
    env_raw_state: Any
    # Episode reward accumulated at the moment the snapshot was taken.
    total_reward_snapshot: float
    # Seed the environment was created with (``None`` if unseeded).
    crafter_seed: Optional[int] = None
    # ``dataclasses.asdict`` forms of the last public/private states, kept so
    # reward bookkeeping can pick up where it left off after a restore.
    previous_public_state_snapshot: Optional[Dict] = None
    previous_private_state_snapshot: Optional[Dict] = None
    # The RewardStack's configuration is fixed when the engine is constructed,
    # so it is not stored here; any internal state a reward component grows
    # would need to be added as another field.
58
+
59
+
60
@dataclass
class CrafterPublicState:
    """Player-visible view of the Crafter world after a reset or step.

    Holds inventory counts, per-achievement unlock flags, the player's
    position and facing, an optional semantic map, the world material map,
    the rendered frame and the step counters. ``error_info`` carries a
    message when the state could not be built normally.
    """

    inventory: Dict[str, int]
    achievements_status: Dict[str, bool]
    player_position: Tuple[int, int]
    player_direction: Union[int, Tuple[int, int]]
    semantic_map: Optional[np.ndarray]
    world_material_map: np.ndarray
    observation_image: np.ndarray
    num_steps_taken: int
    max_steps_episode: int
    error_info: Optional[str] = None

    def diff(self, prev_state: "CrafterPublicState") -> Dict[str, Any]:
        """Return the fields that differ from ``prev_state``.

        Array-valued fields map to ``True`` when they differ (the arrays are
        not echoed back); every other changed field maps to ``(old, new)``.
        Unchanged fields are omitted.
        """
        changes: Dict[str, Any] = {}
        for f in dataclasses.fields(self):
            new_v = getattr(self, f.name)
            old_v = getattr(prev_state, f.name)
            # Check BOTH sides: if only the old value is an array (e.g.
            # semantic_map going from an ndarray to None), a plain ``!=``
            # yields an element-wise array whose truth value raises.
            if isinstance(new_v, np.ndarray) or isinstance(old_v, np.ndarray):
                if not np.array_equal(new_v, old_v):
                    changes[f.name] = True
            elif new_v != old_v:
                changes[f.name] = (old_v, new_v)
        return changes
83
+
84
+
85
@dataclass
class CrafterPrivateState:
    """Engine-internal view of the episode: rewards, flags, stats and RNG state."""

    reward_last_step: float
    total_reward_episode: float
    achievements_current_values: Dict[str, int]
    terminated: bool
    truncated: bool
    player_internal_stats: Dict[str, Any]
    world_rng_state_snapshot: Any

    @staticmethod
    def _values_equal(a: Any, b: Any) -> bool:
        """Compare two field values, tolerating numpy arrays nested in containers.

        ``world_rng_state_snapshot`` holds a numpy ``get_state()`` tuple that
        contains an ndarray; comparing such tuples with ``==``/``!=`` asks for
        the truth value of an element-wise array and raises ``ValueError``.
        """
        if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
            return bool(np.array_equal(a, b))
        if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)):
            # Mirror Python semantics: a tuple never equals a list.
            return (
                isinstance(a, tuple) == isinstance(b, tuple)
                and len(a) == len(b)
                and all(CrafterPrivateState._values_equal(x, y) for x, y in zip(a, b))
            )
        if isinstance(a, dict) and isinstance(b, dict):
            return a.keys() == b.keys() and all(
                CrafterPrivateState._values_equal(a[k], b[k]) for k in a
            )
        return bool(a == b)

    def diff(self, prev_state: "CrafterPrivateState") -> Dict[str, Any]:
        """Return ``{field: (old, new)}`` for every field that differs from ``prev_state``."""
        changes: Dict[str, Any] = {}
        for f in dataclasses.fields(self):
            new_v = getattr(self, f.name)
            old_v = getattr(prev_state, f.name)
            if not self._values_equal(new_v, old_v):
                changes[f.name] = (old_v, new_v)
        return changes
102
+
103
+
104
+ # -----------------------------------------------------------------------------
105
+ # Observation helpers
106
+ # -----------------------------------------------------------------------------
107
+
108
+
109
class CrafterObservationCallable(GetObservationCallable):
    """Flatten Crafter public + private state into a plain observation dict."""

    def __init__(self) -> None:
        # Nothing to configure; this callable keeps no state of its own.
        pass

    async def get_observation(
        self, pub: CrafterPublicState, priv: CrafterPrivateState
    ) -> InternalObservation:  # type: ignore[override]
        """Return inventory, achievements, position, step count, rewards and end flags."""
        from_public = {
            "inventory": pub.inventory,
            "achievements": pub.achievements_status,
            "player_pos": pub.player_position,
            "steps": pub.num_steps_taken,
        }
        from_private = {
            "reward_last": priv.reward_last_step,
            "total_reward": priv.total_reward_episode,
            "terminated": priv.terminated,
            "truncated": priv.truncated,
        }
        merged: Dict[str, Any] = {**from_public, **from_private}
        return merged  # type: ignore[return-value]
127
+
128
+
129
+ # -----------------------------------------------------------------------------
130
+ # CrafterEngine implementation
131
+ # -----------------------------------------------------------------------------
132
+
133
+
134
class CrafterEngine(StatefulEngine, IReproducibleEngine):
    """StatefulEngine wrapper around `crafter.Env` supporting full snapshotting.

    The engine owns one ``crafter.Env`` built from the task instance's
    configuration, keeps its own running episode reward, and converts the
    environment into ``CrafterPublicState`` / ``CrafterPrivateState`` objects
    after every reset and step.
    """

    task_instance: TaskInstance
    env: crafter.Env

    # ────────────────────────────────────────────────────────────────────────
    # Construction helpers
    # ────────────────────────────────────────────────────────────────────────

    def __init__(self, task_instance: TaskInstance):
        """Create the underlying ``crafter.Env`` for ``task_instance``.

        Settings come from ``task_instance.config`` (``area`` default (64, 64),
        ``length`` default 10000, ``seed``, ``world_config``,
        ``world_config_path``). Attributes of the same name on
        ``task_instance.metadata`` take precedence for ``seed``,
        ``world_config`` and ``world_config_path``.
        """
        self.task_instance = task_instance
        self._total_reward: float = 0.0
        self._current_action_for_reward: Optional[int] = None
        self._previous_public_state_for_reward: Optional[CrafterPublicState] = None
        # Previous private state, kept for stat-change bookkeeping.
        self._previous_private_state_for_reward: Optional[CrafterPrivateState] = None

        # Achievement names already observed as unlocked in the current episode.
        self.achievements_unlocked: set = set()

        cfg = getattr(task_instance, "config", {}) or {}
        area: Tuple[int, int] = tuple(cfg.get("area", (64, 64)))  # type: ignore[arg-type]
        length: int = int(cfg.get("length", 10000))
        # None when the task instance has no metadata; hasattr(None, ...) is False.
        metadata = getattr(task_instance, "metadata", None)

        # Seed: metadata wins over config.
        seed: Optional[int] = cfg.get("seed")
        if hasattr(metadata, "seed"):
            seed = metadata.seed

        # World configuration: metadata, then config, then the "normal" default.
        world_config = "normal"
        if hasattr(metadata, "world_config"):
            world_config = metadata.world_config
            logger.info(f"CrafterEngine: Using world_config from metadata: {world_config}")
        elif cfg.get("world_config"):
            world_config = cfg.get("world_config")
            logger.info(f"CrafterEngine: Using world_config from cfg: {world_config}")

        world_config_path = None
        if hasattr(metadata, "world_config_path"):
            world_config_path = metadata.world_config_path
        elif cfg.get("world_config_path"):
            world_config_path = cfg.get("world_config_path")

        # Remembered so a re-seeding reset rebuilds the same kind of world.
        self._world_config = world_config
        self._world_config_path = world_config_path

        logger.info(f"CrafterEngine: Creating Env with world_config={world_config}, seed={seed}")
        self.env = crafter.Env(
            area=area,
            length=length,
            seed=seed,
            world_config=world_config,
            world_config_path=world_config_path,
        )
        # Store the original seed for reproducibility (read back by snapshots).
        self.env._seed = seed

        self.reward_stack = RewardStack(
            components=[
                CrafterAchievementComponent(),
                CrafterPlayerStatComponent(),
                CrafterStepPenaltyComponent(penalty=-0.001),
            ]
        )

    # ────────────────────────────────────────────────────────────────────────
    # Utility: action validation / mapping
    # ────────────────────────────────────────────────────────────────────────

    def _validate_action_engine(self, action: Union[int, str]) -> int:  # type: ignore[override]
        """Map ``action`` to a valid Crafter action index.

        Strings are looked up in ``CRAFTER_ACTION_MAP`` (unknown names map to
        0). Python or NumPy integers are clipped into
        ``[0, len(crafter.constants.actions) - 1]``. Any other value maps to 0.
        """
        if isinstance(action, str):
            action = CRAFTER_ACTION_MAP.get(action, 0)
        # np.integer is accepted too: actions sampled from numpy arrays are not `int`.
        if not isinstance(action, (int, np.integer)):
            return 0
        return int(np.clip(action, 0, len(crafter.constants.actions) - 1))  # type: ignore

    # ────────────────────────────────────────────────────────────────────────
    # Core StatefulEngine API
    # ────────────────────────────────────────────────────────────────────────

    async def _reset_engine(
        self, *, seed: Optional[int] = None
    ) -> Tuple[CrafterPrivateState, CrafterPublicState]:
        """Start a new episode, optionally re-seeding the world.

        When ``seed`` is given the ``crafter.Env`` is rebuilt with that seed and
        the same area, length and world configuration as before. Returns
        ``(private_state, public_state)`` for the fresh episode.
        """
        if seed is not None:
            # Re-instantiate env with the new seed to match crafter's reseeding
            # convention. The world settings must be passed through; omitting
            # them silently falls back to the default world generation.
            self.env = crafter.Env(
                area=self.env._area,
                length=self.env._length,
                seed=seed,
                world_config=getattr(self, "_world_config", "normal"),
                world_config_path=getattr(self, "_world_config_path", None),
            )
            self.env._seed = seed
        obs_img = self.env.reset()
        self._total_reward = 0.0
        # A new episode starts with no achievements observed.
        self.achievements_unlocked = set()
        pub = self._build_public_state(obs_img)
        priv = self._build_private_state(reward=0.0, terminated=False, truncated=False)
        # Baseline for reward bookkeeping on the first step.
        self._previous_public_state_for_reward = pub
        self._previous_private_state_for_reward = priv
        return priv, pub

    async def _step_engine(self, action: int) -> Tuple[CrafterPrivateState, CrafterPublicState]:
        """Advance the environment by one action.

        Returns ``(private_state, public_state)`` after the step. The private
        state's ``reward_last_step`` is this step's reward and
        ``total_reward_episode`` the running total. Any exception (including an
        out-of-range ``action``) is logged and turned into an error state with
        ``reward_last_step=-1.0`` and ``terminated=True``.
        """
        step_start_time = time.perf_counter()
        try:
            n_actions = self.env.action_space.n
            if action < 0 or action >= n_actions:
                raise ValueError(f"Invalid action {action}, must be in range [0, {n_actions})")

            self._ensure_player()

            crafter_step_start = time.perf_counter()
            obs, reward, done, info = self.env.step(action)
            crafter_step_time = time.perf_counter() - crafter_step_start
            logger.debug(f"Crafter env.step() took {crafter_step_time:.3f}s")

            self.obs = obs
            self.done = done
            self.info = info
            self.last_reward = reward
            # Step count is tracked by crafter itself in self.env._step.

            for achievement, status in info.get("achievements", {}).items():
                if status:
                    self.achievements_unlocked.add(achievement)

            reward_from_stack = self._drain_external_reward_stack()
            step_reward = reward + reward_from_stack
            self._total_reward = self._total_reward + reward + reward_from_stack

            terminated, truncated = self._termination_flags(done)
            # reward_last_step must be the per-step reward, not the running total.
            final_priv_state = self._build_private_state(step_reward, terminated, truncated)
            # Public state after the step reflects the latest world and achievements.
            pub_state_after = self._build_public_state(obs, info)

            # Post-step states become the baseline for the next step.
            self._previous_public_state_for_reward = pub_state_after
            self._previous_private_state_for_reward = final_priv_state

            total_step_time = time.perf_counter() - step_start_time
            logger.debug(
                f"CrafterEngine _step_engine took {total_step_time:.3f}s (crafter.step: {crafter_step_time:.3f}s)"
            )
            return final_priv_state, pub_state_after

        except Exception as e:
            # logger.exception records the message together with the traceback.
            logger.exception(f"Step engine error: {e}")
            error_pub_state = self._get_public_state_from_env()
            error_pub_state.error_info = f"Step engine error: {e}"
            error_priv_state = self._get_private_state_from_env(
                reward=-1.0, terminated=True, truncated=False
            )
            return error_priv_state, error_pub_state

    def _ensure_player(self) -> None:
        """Re-attach ``env._player`` from the world's objects when it is ``None``.

        Raises:
            RuntimeError: if no object of class ``Player`` exists in the world.
        """
        if self.env._player is not None:
            return
        for obj in self.env._world._objects:
            if obj is not None and obj.__class__.__name__ == "Player":
                self.env._player = obj
                return
        raise RuntimeError("Player object not found in world")

    def _drain_external_reward_stack(self) -> float:
        """Sum and clear ``self._reward_stack`` if something has populated it.

        Returns 0 when the attribute is missing, empty, or cannot be summed.
        NOTE(review): nothing in this class creates ``_reward_stack`` (the
        configured stack is ``self.reward_stack``), so this is 0 unless an
        outside caller sets it — confirm whether it is still needed.
        """
        try:
            pending = getattr(self, "_reward_stack", None)
            if not pending:
                return 0
            total = sum(pending)
            pending.clear()
            return total
        except Exception as e:
            logger.debug(f"Ignoring unusable _reward_stack: {e}")
            return 0

    def _termination_flags(self, done: bool) -> Tuple[bool, bool]:
        """Split crafter's single ``done`` flag into ``(terminated, truncated)``.

        ``terminated`` is True only if the player died (health <= 0). Any other
        end of episode — the step limit, or an unexplained ``done`` — is
        reported as ``truncated``.
        """
        if not done:
            return False, False
        if self.env._player.health <= 0:  # type: ignore[attr-defined]
            return True, False
        return False, True

    # ------------------------------------------------------------------
    # Rendering (simple text summary)
    # ------------------------------------------------------------------

    async def _render(
        self,
        private_state: CrafterPrivateState,
        public_state: CrafterPublicState,
        get_observation: Optional[GetObservationCallable] = None,
    ) -> str:  # type: ignore[override]
        """Render a short text summary of the given state.

        A string observation is returned verbatim. A dict observation is
        summarised as step counter, rewards, non-zero inventory and unlocked
        achievements. Anything else is converted with ``str()``.
        """
        obs_cb = get_observation or CrafterObservationCallable()
        obs = await obs_cb.get_observation(public_state, private_state)
        if isinstance(obs, str):
            return obs
        if isinstance(obs, dict):
            header = f"steps: {public_state.num_steps_taken}/{public_state.max_steps_episode} | "
            header += f"last_r: {private_state.reward_last_step:.2f} | total_r: {private_state.total_reward_episode:.2f}"
            inv = ", ".join(f"{k}:{v}" for k, v in public_state.inventory.items() if v)
            ach = ", ".join(k for k, v in public_state.achievements_status.items() if v)
            return f"{header}\ninv: {inv}\nach: {ach}"
        return str(obs)

    # ------------------------------------------------------------------
    # Snapshotting for exact reproducibility
    # ------------------------------------------------------------------

    async def _serialize_engine(self) -> CrafterEngineSnapshot:
        """Capture the env's saved state plus engine bookkeeping in a snapshot."""
        prev_pub = self._previous_public_state_for_reward
        prev_priv = self._previous_private_state_for_reward
        return CrafterEngineSnapshot(
            env_raw_state=self.env.save(),
            total_reward_snapshot=self._total_reward,
            crafter_seed=self.env._seed,
            previous_public_state_snapshot=dataclasses.asdict(prev_pub) if prev_pub is not None else None,
            previous_private_state_snapshot=dataclasses.asdict(prev_priv) if prev_priv is not None else None,
        )

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: CrafterEngineSnapshot, task_instance: TaskInstance
    ) -> "CrafterEngine":
        """Build a new engine for ``task_instance`` and load ``snapshot`` into it."""
        engine = cls(task_instance)
        # reset() first so the env's internal structures exist; load() then
        # overwrites them with the saved state.
        engine.env.reset()
        engine.env.load(snapshot.env_raw_state)
        engine._total_reward = snapshot.total_reward_snapshot
        engine.env._seed = snapshot.crafter_seed
        engine._ensure_player()

        # Initialise the attributes _step_engine updates.
        engine.obs = engine.env.render()
        engine.done = False
        engine.info = {}
        engine.last_reward = 0.0

        # Rebuild from the restored player so achievements unlocked before the
        # snapshot are not treated as new after it.
        player = engine.env._player
        engine.achievements_unlocked = {
            name for name, count in player.achievements.items() if safe_compare(0, count, "<")
        }

        # Re-establish previous states so reward bookkeeping continues after load.
        engine._previous_public_state_for_reward = engine._build_public_state(engine.obs)
        # safe_compare avoids str-vs-int comparison errors on restored values.
        health_dead = safe_compare(0, player.health, ">=")
        step_exceeded = safe_compare(engine.env._length, engine.env._step, "<=")
        engine._previous_private_state_for_reward = engine._build_private_state(
            0.0, health_dead, step_exceeded
        )
        return engine

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_public_state(
        self, obs_img: np.ndarray, info: Optional[Dict[str, Any]] = None
    ) -> CrafterPublicState:
        """Build a ``CrafterPublicState`` from the env (or from a step ``info`` dict).

        Without ``info``, inventory, achievements and the semantic view are
        read from the player/env directly. With ``info``, they are taken from
        its ``inventory``, ``achievements`` and ``semantic`` keys. On any error
        a minimal placeholder state carrying ``error_info`` is returned.
        """
        try:
            player = self.env._player  # type: ignore[attr-defined]
            if info is None:
                achievements_source = player.achievements
                inventory = player.inventory.copy()
                semantic = getattr(self.env, "_sem_view", lambda: None)()
            else:
                achievements_source = info.get("achievements", {})
                inventory = info.get("inventory", {})
                semantic = info.get("semantic")
            # safe_compare guards against non-numeric achievement counters.
            achievements_status = {
                k: safe_compare(0, v, "<") for k, v in achievements_source.items()
            }

            return CrafterPublicState(
                inventory=inventory,
                achievements_status=achievements_status,
                player_position=tuple(player.pos),  # type: ignore[attr-defined]
                player_direction=player.facing,  # type: ignore[attr-defined]
                semantic_map=semantic,
                world_material_map=self.env._world._mat_map.copy(),  # type: ignore[attr-defined]
                observation_image=obs_img,
                num_steps_taken=self.env._step,  # type: ignore[attr-defined]
                max_steps_episode=self.env._length,  # type: ignore[attr-defined]
                error_info=info.get("error_info") if info else None,
            )
        except Exception as e:
            logger.error(f"Error building public state: {e}")
            # Return minimal safe state
            return CrafterPublicState(
                inventory={},
                achievements_status={},
                player_position=(0, 0),
                player_direction=0,
                semantic_map=None,
                world_material_map=np.zeros((1, 1), dtype=np.uint8),
                observation_image=obs_img
                if obs_img is not None
                else np.zeros((64, 64, 3), dtype=np.uint8),
                num_steps_taken=0,
                max_steps_episode=10000,
                error_info=f"State building error: {e}",
            )

    def _build_private_state(
        self, reward: float, terminated: bool, truncated: bool
    ) -> CrafterPrivateState:
        """Build a ``CrafterPrivateState`` from the player, world RNG and running total.

        ``reward`` becomes ``reward_last_step``. ``total_reward_episode`` is
        always taken from ``self._total_reward``.
        """
        player = self.env._player  # type: ignore[attr-defined]
        stats = {
            "health": player.health,
            "food": player.inventory.get("food"),
            "drink": player.inventory.get("drink"),
            "energy": player.inventory.get("energy"),
            "_hunger": getattr(player, "_hunger", 0),
            "_thirst": getattr(player, "_thirst", 0),
        }
        return CrafterPrivateState(
            reward_last_step=reward,
            total_reward_episode=self._total_reward,
            achievements_current_values=player.achievements.copy(),
            terminated=terminated,
            truncated=truncated,
            player_internal_stats=stats,
            world_rng_state_snapshot=self.env._world.random.get_state(),  # type: ignore[attr-defined]
        )

    def _get_public_state_from_env(self) -> CrafterPublicState:
        """Render the env and build its public state; fall back to a placeholder on error."""
        try:
            obs_img = self.env.render()
            return self._build_public_state(obs_img)
        except Exception as e:
            logger.error(f"Error getting public state from env: {e}")
            # Return default state
            return CrafterPublicState(
                inventory={},
                achievements_status={},
                player_position=(0, 0),
                player_direction=0,
                semantic_map=None,
                world_material_map=np.zeros((1, 1), dtype=np.uint8),
                observation_image=np.zeros((64, 64, 3), dtype=np.uint8),
                num_steps_taken=0,
                max_steps_episode=10000,
                error_info=f"State extraction error: {e}",
            )

    def _get_private_state_from_env(
        self, reward: float, terminated: bool, truncated: bool
    ) -> CrafterPrivateState:
        """Build the private state; fall back to an empty placeholder on error."""
        try:
            return self._build_private_state(reward, terminated, truncated)
        except Exception as e:
            logger.error(f"Error getting private state from env: {e}")
            # Return default state
            return CrafterPrivateState(
                reward_last_step=reward,
                total_reward_episode=0.0,
                achievements_current_values={},
                terminated=terminated,
                truncated=truncated,
                player_internal_stats={},
                world_rng_state_snapshot=None,
            )
548
+
549
+
550
+ # --- Reward Components ---
551
class CrafterAchievementComponent(RewardComponent):
    """Pay 0.1 for each achievement that is unlocked now but was not before."""

    async def score(self, state: CrafterPublicState, action: Dict[str, Any]) -> float:
        """Score newly unlocked achievements.

        ``action["previous_public_state_achievements"]`` supplies the earlier
        unlock flags; when it is absent, nothing counts as previously unlocked.
        """
        before = action.get("previous_public_state_achievements", {})
        newly_unlocked = [
            name
            for name, unlocked in state.achievements_status.items()
            if unlocked and not before.get(name)
        ]
        return float(len(newly_unlocked)) * 0.1
561
+
562
+
563
class CrafterPlayerStatComponent(RewardComponent):
    """Penalise (-0.05) any step on which the player's health went down."""

    async def score(self, state: CrafterPrivateState, action: Dict[str, Any]) -> float:
        """Compare current health with ``action["previous_private_state_stats"]["health"]``.

        Missing current health counts as 0. A missing previous value is treated
        as unchanged, so no penalty is applied.
        """
        health_now = state.player_internal_stats.get("health", 0)
        stats_before = action.get("previous_private_state_stats", {})
        health_before = stats_before.get("health", health_now)
        return -0.05 if health_now < health_before else 0.0
570
+
571
+
572
class CrafterStepPenaltyComponent(RewardComponent):
    """Constant per-step reward term (a small negative value by default)."""

    def __init__(self, penalty: float = -0.001):
        super().__init__()
        # Unit weight: the magnitude is carried entirely by ``penalty``.
        self.weight = 1.0
        self.penalty = penalty

    async def score(self, state: Any, action: Any) -> float:
        """Return the configured penalty, independent of state and action."""
        return self.penalty
@@ -0,0 +1,63 @@
1
+ """
2
+ Apply once (import this module anywhere before CrafterEngine is used).
3
+ It replaces Env._balance_object so that every per-chunk object list is
4
+ sorted by (x, y, class-name) before any random choice is made – removing
5
+ the hash-based set-iteration nondeterminism that caused the drift.
6
+ """
7
+
8
+ import collections
9
+ import crafter
10
+
11
print("[PATCH] Attempting to apply Crafter deterministic patch...")


def _patch_balance_object() -> None:
    """Make ``crafter.Env._balance_object`` see each chunk's objects in sorted order.

    Objects are sorted by (x, y, class name) before the original method runs,
    so its random choices do not depend on set iteration order. The original
    is kept as ``crafter.Env._orig_balance_object``; its presence marks the
    patch as already applied.
    """
    if hasattr(crafter.Env, "_orig_balance_object"):
        print("[PATCH] crafter.Env._balance_object already patched or _orig_balance_object exists.")
        return

    print("[PATCH] Patching crafter.Env._balance_object...")
    crafter.Env._orig_balance_object = crafter.Env._balance_object

    def _balance_object_det(self, chunk, objs, *args, **kwargs):
        # Every remaining argument (cls, material, spawn/despawn parameters,
        # ctor, target_fn, ...) is forwarded untouched.
        ordered = sorted(objs, key=lambda o: (o.pos[0], o.pos[1], o.__class__.__name__))
        return crafter.Env._orig_balance_object(self, chunk, ordered, *args, **kwargs)

    crafter.Env._balance_object = _balance_object_det
    print("[PATCH] crafter.Env._balance_object patched.")


def _patch_chunk_order() -> None:
    """Make ``World.chunks`` iterate its chunks in sorted key order."""
    world_cls = crafter.engine.World
    if hasattr(world_cls, "_orig_chunks_prop"):
        return
    world_cls._orig_chunks_prop = world_cls.chunks

    def _chunks_sorted(self):
        # A fresh OrderedDict built from sorted items iterates in key order.
        return collections.OrderedDict(sorted(self._chunks.items()))

    world_cls.chunks = property(_chunks_sorted)


def _patch_object_order() -> None:
    """Make ``World.objects`` list live objects in a stable, position-based order."""
    world_cls = crafter.engine.World
    if hasattr(world_cls, "_orig_objects_prop"):
        return
    world_cls._orig_objects_prop = world_cls.objects  # save original

    def _objects_sorted(self):
        live = [o for o in self._objects if o]  # None entries are removed objects
        # Stable order: x, y, class name, creation index.
        return sorted(
            live,
            key=lambda o: (o.pos[0], o.pos[1], o.__class__.__name__, getattr(o, "_id", 0)),
        )

    world_cls.objects = property(_objects_sorted)


_patch_balance_object()
_patch_chunk_order()
_patch_object_order()
@@ -0,0 +1,5 @@
1
+ import crafter.constants as C
2
+ from typing import Dict
3
+
4
# Lookup table: Crafter action name -> its position in ``C.actions``.
CRAFTER_ACTION_MAP: Dict[str, int] = dict(zip(C.actions, range(len(C.actions))))