synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +579 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/environments/examples/wordle/__init__.py +29 -0
  100. synth_ai/environments/examples/wordle/engine.py +391 -0
  101. synth_ai/environments/examples/wordle/environment.py +154 -0
  102. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  103. synth_ai/environments/examples/wordle/taskset.py +222 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/environments/service/core_routes.py +38 -0
  106. synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
  107. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
  108. synth_ai/learning/prompts/mipro.py +273 -1
  109. synth_ai/learning/prompts/random_search.py +247 -0
  110. synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
  111. synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
  112. synth_ai/lm/injection.py +81 -0
  113. synth_ai/lm/overrides.py +204 -0
  114. synth_ai/lm/provider_support/anthropic.py +39 -12
  115. synth_ai/lm/provider_support/openai.py +31 -4
  116. synth_ai/lm/vendors/core/anthropic_api.py +16 -0
  117. synth_ai/lm/vendors/openai_standard.py +35 -5
  118. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
  119. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
  120. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
  121. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
  122. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
  123. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,312 @@
1
+ """CrafterCustomEnvironment — Custom Crafter with configurable world generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import List, Optional, Any, Dict, Union
6
+ import dataclasses
7
+ import logging
8
+ import time
9
+
10
+ # Import logging configuration to suppress JAX debug messages
11
+ from synth_ai.environments.examples.crafter_classic.config_logging import safe_compare
12
+
13
+ # Import tracing abstractions
14
+ from synth_ai.tracing_v3.abstractions import (
15
+ RuntimeEvent,
16
+ SessionEventMarkovBlanketMessage,
17
+ TimeRecord,
18
+ )
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Import the base Crafter components
23
+ from synth_ai.environments.examples.crafter_classic.engine import (
24
+ CrafterEngine,
25
+ CrafterPrivateState,
26
+ CrafterPublicState,
27
+ CrafterEngineSnapshot,
28
+ )
29
+ from synth_ai.environments.examples.crafter_classic.taskset import CrafterTaskInstance
30
+ from synth_ai.environments.environment.shared_engine import (
31
+ GetObservationCallable,
32
+ InternalObservation,
33
+ )
34
+ from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
35
+ from synth_ai.environments.stateful.core import StatefulEnvironment
36
+ from synth_ai.environments.environment.tools import (
37
+ AbstractTool,
38
+ EnvToolCall,
39
+ ToolResult,
40
+ TOOL_REGISTRY,
41
+ register_tool,
42
+ )
43
+ from pydantic import BaseModel, Field
44
+
45
+
46
+ # Use the same tool and observation classes as CrafterClassic
47
+ from synth_ai.environments.examples.crafter_classic.environment import (
48
+ CrafterActionInput,
49
+ CrafterInteractTool,
50
+ SynthCrafterObservationCallable,
51
+ )
52
+
53
+
54
class CrafterCustomEnvironment(StatefulEnvironment, ReproducibleEnvironment[CrafterEngine]):
    """Custom Crafter environment with configurable world generation.

    Wraps a ``CrafterEngine`` built from the task instance and drives it through a
    single ``CrafterInteractTool``. When a session tracer with an active
    ``current_session`` is supplied, tool-call validation and observation
    generation are each recorded on that session as a ``RuntimeEvent``.
    """

    def __init__(
        self,
        task_instance: "CrafterTaskInstance",
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
        session_tracer: Optional[Any] = None,  # SessionTracer from higher level
    ) -> None:
        """Build the engine and interact tool for ``task_instance``.

        Args:
            task_instance: Task description handed to ``CrafterEngine``.
            custom_step_obs: Observation callable used by ``initialize``,
                ``step`` and ``terminate``; defaults to
                ``SynthCrafterObservationCallable()``.
            custom_ckpt_obs: Observation callable used by ``checkpoint``;
                same default.
            session_tracer: Optional tracer; stored and also forwarded to the
                interact tool.
        """
        self.name = "CrafterCustom"
        self.task_instance = task_instance
        self.custom_step_observation_callable = custom_step_obs or SynthCrafterObservationCallable()
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or SynthCrafterObservationCallable()
        )

        # Purely informational: log which world configuration the task carries.
        if hasattr(task_instance, "metadata"):
            logger.info(
                "Creating CrafterCustom with world_config: %s",
                getattr(task_instance.metadata, "world_config", "default"),
            )

        self.engine = CrafterEngine(task_instance)
        self.session_tracer = session_tracer  # Store tracer for runtime events

        self._interact_tool = CrafterInteractTool(self.engine, session_tracer=session_tracer)

        # Register the tool under an environment-specific name. Only an instance
        # that finds the name missing from TOOL_REGISTRY renames and registers its
        # tool; otherwise the tool keeps its default name and is not registered.
        tool_name = f"{self.name.lower()}_interact"
        if tool_name not in TOOL_REGISTRY:
            self._interact_tool.name = tool_name
            register_tool(self._interact_tool)

    # ────────────────────────────────────────────────────────────────────
    # Lifecycle helpers
    # ────────────────────────────────────────────────────────────────────

    async def initialize(self) -> InternalObservation:  # type: ignore[override]
        """Reset the engine and return the first step observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def terminate(self) -> InternalObservation:  # type: ignore[override]
        """Return a final observation flagged as terminated.

        The observation is produced by the step callable and, when it is a dict,
        gains a ``"status"`` entry.
        """
        pub = self.engine._get_public_state_from_env()
        priv = self.engine._get_private_state_from_env(0, True, False)  # Terminated state
        priv.terminated = True
        obs_dict = {"status": "Environment terminated."}
        return await self._to_observation(
            priv, pub, self.custom_step_observation_callable, extra_obs=obs_dict
        )

    # ────────────────────────────────────────────────────────────────────
    # Step + checkpoint
    # ────────────────────────────────────────────────────────────────────

    def validate_tool_calls(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> EnvToolCall:
        """Normalize ``tool_calls`` to the single ``EnvToolCall`` to execute.

        Accepts one call, a list of calls, or a list of lists of calls; only the
        first call is used in the list forms.

        Raises:
            ValueError: On an empty (outer or inner) list, or when the call's tool
                name is neither ``"interact"`` nor ``"<name>_interact"``.
            TypeError: When the input, or the selected element, is not an
                ``EnvToolCall``.
        """
        # Keep the raw input so the trace shows what the caller actually sent.
        state_before = {"tool_calls": tool_calls}

        if isinstance(tool_calls, list):
            if not tool_calls:
                raise ValueError("Received empty list of tool calls.")
            first = tool_calls[0]
            if isinstance(first, list):
                if not first:
                    raise ValueError("Received empty inner list of tool calls.")
                agent_call = first[0]
            else:
                agent_call = first
        elif isinstance(tool_calls, EnvToolCall):
            agent_call = tool_calls
        else:
            raise TypeError(f"Unexpected type for tool_calls: {type(tool_calls)}")

        if not isinstance(agent_call, EnvToolCall):
            raise TypeError(f"Processed call is not EnvToolCall: {type(agent_call)}")

        # Accept both "interact" and "craftercustom_interact"
        custom_tool_name = f"{self.name.lower()}_interact"
        if agent_call.tool not in ["interact", custom_tool_name]:
            raise ValueError(
                f"Unknown tool: {agent_call.tool}. Expected 'interact' or '{custom_tool_name}'."
            )

        self._record_runtime_event(
            system_instance_id="crafter_custom_environment",
            state_before=state_before,
            state_after={"validated_call": agent_call},
            metadata={"validation_step": "tool_call_validation"},
        )
        return agent_call

    async def step(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> InternalObservation:  # type: ignore[override]
        """Execute one tool call and return the resulting step observation.

        State objects are taken from the tool result payload when they have the
        expected types; otherwise the current state is read from the engine. On a
        failed call the tool's error message, if any, is written to
        ``pub_state.error_info``; on a successful call with unusable payload a
        fixed diagnostic is written there instead.
        """
        agent_call = self.validate_tool_calls(tool_calls)
        tool_result: ToolResult = await self._interact_tool(agent_call)

        # A result may carry no payload at all; treat anything that is not a dict
        # as "no state supplied" instead of failing on ``.get``.
        raw_payload = tool_result.payload
        payload_dict: Dict[str, Any] = raw_payload if isinstance(raw_payload, dict) else {}

        priv_state = payload_dict.get("private_state")
        pub_state = payload_dict.get("public_state")
        states_usable = isinstance(priv_state, CrafterPrivateState) and isinstance(
            pub_state, CrafterPublicState
        )

        if not states_usable:
            if tool_result.ok:
                # A successful call is expected to carry state, so this is an error.
                logger.error(
                    "Invalid state types in payload: priv=%s, pub=%s",
                    type(priv_state),
                    type(pub_state),
                )
            priv_state, pub_state = self._current_states()
            if tool_result.ok:
                pub_state.error_info = "Invalid state types in tool result"

        if not tool_result.ok and tool_result.error:
            pub_state.error_info = tool_result.error

        return await self._to_observation(
            priv_state, pub_state, self.custom_step_observation_callable
        )

    async def checkpoint(self) -> InternalObservation:  # type: ignore[override]
        """Return a checkpoint observation with the serialized engine attached.

        The snapshot is added under ``"engine_snapshot_data"`` only when the
        observation callable produced a dict.
        """
        engine_snapshot: CrafterEngineSnapshot = await self.engine._serialize_engine()
        priv = self.engine._get_private_state_from_env(0, False, False)  # Get current state for obs
        pub = self.engine._get_public_state_from_env()
        obs_data = await self._to_observation(
            priv, pub, self.custom_checkpoint_observation_callable
        )
        if isinstance(obs_data, dict):
            obs_data["engine_snapshot_data"] = engine_snapshot.model_dump()
        return obs_data

    async def get_metadata(self) -> Dict[str, Any]:
        """Get metadata about the current environment configuration.

        Combines engine attributes, selected task-instance metadata fields
        (``None`` where absent) and, when the engine exposes a world, a count of
        world objects keyed by class name.
        """
        crafter_env = self.engine.env
        metadata: Dict[str, Any] = {
            "environment_type": "CrafterCustom",
            "engine_seed": getattr(crafter_env, "_seed", None),
            "world_area": crafter_env._area,
            "max_steps": crafter_env._length,
            "current_step": crafter_env._step,
        }

        # Add task instance metadata
        if hasattr(self.task_instance, "metadata"):
            task_metadata = self.task_instance.metadata
            for field_name in (
                "difficulty",
                "world_config",
                "world_config_path",
                "num_trees_radius",
                "num_cows_radius",
                "num_hostiles_radius",
            ):
                metadata[field_name] = getattr(task_metadata, field_name, None)

        # Add current world statistics
        if hasattr(self.engine, "env") and hasattr(self.engine.env, "_world"):
            object_counts: Dict[str, int] = {}
            for obj in self.engine.env._world._objects:
                if obj is None:  # skipped: the object list may contain None entries
                    continue
                type_name = type(obj).__name__
                object_counts[type_name] = object_counts.get(type_name, 0) + 1
            metadata["world_object_counts"] = object_counts

        return metadata

    # ────────────────────────────────────────────────────────────────────
    # Helpers
    # ────────────────────────────────────────────────────────────────────

    def _current_states(self):
        """Read the current ``(private_state, public_state)`` from the engine.

        Used by ``step`` when a tool result does not carry usable state objects.
        """
        pub_state = self.engine._get_public_state_from_env()
        crafter_env = self.engine.env
        # NOTE(review): safe_compare(a, b, op) presumably evaluates ``a op b``
        # defensively, i.e. "0 >= health" and "length <= step" — confirm.
        health_dead = safe_compare(0, crafter_env._player.health, ">=")
        step_exceeded = safe_compare(crafter_env._length, crafter_env._step, "<=")
        # NOTE(review): arguments look like (reward, terminated, truncated) — confirm
        # against CrafterEngine._get_private_state_from_env.
        priv_state = self.engine._get_private_state_from_env(0, health_dead, step_exceeded)
        return priv_state, pub_state

    def _record_runtime_event(
        self,
        *,
        system_instance_id: str,
        state_before: Dict[str, Any],
        state_after: Dict[str, Any],
        metadata: Dict[str, Any],
    ) -> None:
        """Append a ``RuntimeEvent`` to the tracer's current session, if there is one.

        Does nothing when no tracer was supplied, or when the tracer has no
        truthy ``current_session``.
        """
        tracer = self.session_tracer
        if not tracer or not hasattr(tracer, "current_session"):
            return
        session = tracer.current_session
        if not session:
            return

        event = RuntimeEvent()
        event.time_record = TimeRecord()
        event.time_record.event_time = time.time()
        event.time_record.message_time = None
        event.system_instance_id = system_instance_id
        event.system_state_before = state_before
        event.system_state_after = state_after
        event.metadata = metadata
        # Add directly to event history, bypassing timestep requirement
        session.add_event(event)

    async def _to_observation(
        self,
        priv: CrafterPrivateState,
        pub: CrafterPublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Turn engine state into an observation and trace the conversion.

        Args:
            priv: Private engine state.
            pub: Public engine state.
            obs_cb: Observation callable; ``SynthCrafterObservationCallable()`` is
                used when this is falsy.
            extra_obs: Extra entries merged into the observation when the
                observation is a dict.
        """
        # Captured before the callable runs so the trace shows its inputs.
        state_before = {"private_state": priv, "public_state": pub}

        active_obs_cb = obs_cb or SynthCrafterObservationCallable()
        observation = await active_obs_cb.get_observation(pub, priv)
        if extra_obs and isinstance(observation, dict):
            observation.update(extra_obs)

        self._record_runtime_event(
            system_instance_id="observation_generator",
            state_before=state_before,
            state_after={"observation": observation},
            metadata={"observation_step": "state_to_obs_conversion"},
        )
        return observation

    # ────────────────────────────────────────────────────────────────────
    # ReproducibleEnvironment plumbing
    # ────────────────────────────────────────────────────────────────────

    async def _serialize_engine(self) -> CrafterEngineSnapshot:
        """Delegate serialization to the wrapped engine."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: CrafterEngineSnapshot, task_instance: "CrafterTaskInstance"
    ) -> "CrafterCustomEnvironment":
        """Rebuild an environment around an engine restored from ``snapshot``."""
        eng = await CrafterEngine._deserialize_engine(snapshot, task_instance)
        env = cls(task_instance)
        env.engine = eng
        # CRITICAL: Update the interact tool to use the new engine!
        # (the constructor bound the tool to the engine it created itself)
        env._interact_tool.engine = eng
        return env
@@ -0,0 +1,305 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Run script for Crafter dataset instances
4
+ """
5
+
6
+ import json
7
+ import argparse
8
+ import random
9
+ from pathlib import Path
10
+ from typing import List, Optional, Dict, Any
11
+ import uuid
12
+ import os
13
+ import sys
14
+
15
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
16
+
17
+ from crafter import Env
18
+
19
+
20
class CrafterDatasetRunner:
    """Run Crafter instances from a dataset"""

    def __init__(self, dataset_path: Path = Path("dataset")):
        """Remember the directory that holds one sub-directory per dataset."""
        self.dataset_path = dataset_path

    def load_dataset(self, dataset_name: str) -> Dict[str, Any]:
        """Load a dataset from disk.

        Reads ``metadata.json`` and ``instances.json`` from
        ``<dataset_path>/<dataset_name>/``.

        Returns:
            ``{"metadata": <parsed metadata>, "instances": <parsed instances>}``.
        """
        dataset_dir = self.dataset_path / dataset_name

        # Explicit UTF-8 so parsing does not depend on the platform's default encoding.
        with open(dataset_dir / "metadata.json", "r", encoding="utf-8") as f:
            metadata = json.load(f)

        with open(dataset_dir / "instances.json", "r", encoding="utf-8") as f:
            instances = json.load(f)

        return {"metadata": metadata, "instances": instances}

    def filter_instances(
        self,
        instances: List[Dict[str, Any]],
        difficulties: Optional[List[str]] = None,
        impetus_types: Optional[List[str]] = None,
        split: Optional[str] = None,
        split_info: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Filter instances based on criteria.

        Each filter is applied only when its argument is truthy; the split filter
        additionally needs ``split_info``. With no active filter the input list is
        returned unchanged (same object).

        Args:
            instances: Instance dicts with ``"id"``, ``"metadata"`` and ``"impetus"``.
            difficulties: Keep instances whose ``metadata.difficulty`` is listed.
            impetus_types: Keep instances whose derived impetus type is listed.
            split: ``"train"``, ``"val"`` or ``"test"``; other values do not filter.
            split_info: Dict with ``"val_instance_ids"`` / ``"test_instance_ids"``.
                "train" is everything in neither list.
        """
        filtered = instances

        if difficulties:
            filtered = [inst for inst in filtered if inst["metadata"]["difficulty"] in difficulties]

        if impetus_types:
            filtered = [inst for inst in filtered if self._get_impetus_type(inst) in impetus_types]

        if split and split_info:
            if split == "train":
                held_out = set(split_info["val_instance_ids"]) | set(
                    split_info["test_instance_ids"]
                )
                filtered = [inst for inst in filtered if inst["id"] not in held_out]
            elif split == "val":
                val_ids = set(split_info["val_instance_ids"])
                filtered = [inst for inst in filtered if inst["id"] in val_ids]
            elif split == "test":
                test_ids = set(split_info["test_instance_ids"])
                filtered = [inst for inst in filtered if inst["id"] in test_ids]

        return filtered

    def _get_impetus_type(self, instance: Dict[str, Any]) -> str:
        """Determine impetus type from instructions.

        Case-insensitive keyword match on ``instance["impetus"]["instructions"]``:
        "speedrun" wins over "focus on"; anything else is "general".
        """
        instructions = instance["impetus"]["instructions"].lower()
        if "speedrun" in instructions:
            return "speedrun"
        if "focus on" in instructions:
            return "focused"
        return "general"

    def run_instance(
        self, instance: Dict[str, Any], render: bool = False, max_steps: int = 1000, agent_fn=None
    ):
        """Run a single instance and return a result dict.

        Args:
            instance: Instance dict; ``metadata.difficulty`` is passed to ``Env`` as
                ``world_config`` and ``metadata.world_seed`` as ``seed``.
            render: Accepted for interface compatibility; not used by this method.
            max_steps: Upper bound on environment steps.
            agent_fn: Optional ``agent_fn(obs, instance) -> action``; a random action
                is sampled from the action space when omitted.

        Returns:
            Dict with ``instance_id``, ``difficulty``, ``seed``, ``steps`` (number of
            environment steps actually executed), ``total_reward``,
            ``achievements`` and ``success``.
        """
        difficulty = instance["metadata"]["difficulty"]
        seed = instance["metadata"]["world_seed"]
        impetus = instance["impetus"]
        rule = "=" * 60

        print(f"\n{rule}")
        print(f"Running instance: {instance['id']}")
        print(f"Difficulty: {difficulty}")
        print(f"Seed: {seed}")
        print(f"Instructions: {impetus['instructions']}")
        if impetus.get("achievement_focus"):
            print(f"Focus: {', '.join(impetus['achievement_focus'])}")
        print(rule)

        # Create environment
        env = Env(seed=seed, world_config=difficulty)
        obs = env.reset()

        total_reward = 0
        achievements = set()
        # Counted explicitly: the loop index is one less than the number of steps
        # executed and is undefined when max_steps <= 0.
        steps_taken = 0

        for _ in range(max_steps):
            action = agent_fn(obs, instance) if agent_fn else env.action_space.sample()

            obs, reward, done, info = env.step(action)
            steps_taken += 1
            total_reward += reward

            # Track achievements
            if "achievements" in info:
                achievements.update(
                    name for name, unlocked in info["achievements"].items() if unlocked
                )

            if done:
                break

        # Evaluate based on intent
        success = self._evaluate_instance(instance, achievements, total_reward, steps_taken)

        print("\nResults:")
        print(f"Steps: {steps_taken}")
        print(f"Total reward: {total_reward}")
        print(f"Achievements: {len(achievements)} - {list(achievements)}")
        print(f"Success: {success}")

        return {
            "instance_id": instance["id"],
            "difficulty": difficulty,
            "seed": seed,
            "steps": steps_taken,
            "total_reward": total_reward,
            "achievements": list(achievements),
            "success": success,
        }

    def _evaluate_instance(
        self, instance: Dict[str, Any], achievements: set, total_reward: float, steps: int
    ) -> bool:
        """Evaluate if instance was successful based on intent.

        Fails when ``intent.minimum_score`` is set and fewer achievements than that
        were unlocked, or when ``intent.target_achievements`` is set and none of
        them was unlocked (one target is enough). ``total_reward`` and ``steps``
        are currently not consulted.
        """
        intent = instance["intent"]

        minimum_score = intent.get("minimum_score")
        if minimum_score and len(achievements) < minimum_score:
            return False

        targets = intent.get("target_achievements")
        if targets and not achievements.intersection(set(targets)):
            return False

        return True

    def run_batch(
        self,
        dataset_name: str,
        num_instances: int = 10,
        difficulties: Optional[List[str]] = None,
        impetus_types: Optional[List[str]] = None,
        split: Optional[str] = None,
        render: bool = False,
        max_steps: int = 1000,
        agent_fn=None,
    ):
        """Run a batch of instances and return their result dicts.

        Loads the dataset, applies the filters, runs either all matching instances
        (when fewer than ``num_instances`` match) or a random sample of
        ``num_instances`` of them, then prints a summary. Returns ``[]`` without
        running anything when no instance matches.
        """
        dataset = self.load_dataset(dataset_name)

        filtered = self.filter_instances(
            dataset["instances"],
            difficulties=difficulties,
            impetus_types=impetus_types,
            split=split,
            split_info=dataset["metadata"].get("split_info"),
        )

        if not filtered:
            print("No instances match the filter criteria!")
            return []

        # Sample instances
        if num_instances > len(filtered):
            print(f"Only {len(filtered)} instances available, running all")
            selected = filtered
        else:
            selected = random.sample(filtered, num_instances)

        print(f"\nRunning {len(selected)} instances from {dataset_name}")
        print(f"Difficulties: {difficulties or 'all'}")
        print(f"Impetus types: {impetus_types or 'all'}")
        print(f"Split: {split or 'all'}")

        results = [
            self.run_instance(instance, render=render, max_steps=max_steps, agent_fn=agent_fn)
            for instance in selected
        ]

        # Summary statistics
        self._print_summary(results)

        return results

    def _print_summary(self, results: List[Dict[str, Any]]):
        """Print summary statistics.

        Prints one row per difficulty (count, success rate, average steps, average
        number of achievements) followed by the overall success rate. With no
        results only the banner and a short notice are printed.
        """
        rule = "=" * 60
        print(f"\n{rule}")
        print("SUMMARY")
        print(rule)

        # An empty batch (e.g. num_instances=0) has no rates to compute; return
        # before any division by the result count.
        if not results:
            print("\nNo results to summarize.")
            return

        # Group by difficulty
        by_difficulty: Dict[str, List[Dict[str, Any]]] = {}
        for result in results:
            by_difficulty.setdefault(result["difficulty"], []).append(result)

        print("\nResults by difficulty:")
        print(
            f"{'Difficulty':<15} {'Count':<8} {'Success':<10} {'Avg Steps':<12} {'Avg Achievements'}"
        )
        print("-" * 60)

        for diff in sorted(by_difficulty):
            diff_results = by_difficulty[diff]
            count = len(diff_results)
            success_rate = sum(1 for r in diff_results if r["success"]) / count
            avg_steps = sum(r["steps"] for r in diff_results) / count
            avg_achievements = sum(len(r["achievements"]) for r in diff_results) / count

            print(
                f"{diff:<15} {count:<8} {success_rate:<10.1%} {avg_steps:<12.1f} {avg_achievements:.1f}"
            )

        # Overall stats
        total_success = sum(1 for r in results if r["success"])
        print(
            f"\nOverall success rate: {total_success}/{len(results)} ({total_success / len(results):.1%})"
        )
262
+
263
+
264
def main():
    """Command-line entry point: parse the options and run one batch.

    Takes the dataset name as a positional argument plus optional instance
    count, difficulty / impetus-type / split filters, a render flag and a
    per-episode step limit, and forwards them to
    ``CrafterDatasetRunner.run_batch`` on a runner with the default dataset path.
    """
    cli = argparse.ArgumentParser(description="Run Crafter dataset instances")
    cli.add_argument("dataset", help="Dataset name")
    cli.add_argument(
        "-n", "--num-instances", type=int, default=10, help="Number of instances to run"
    )

    # The two multi-valued filters differ only in flags, choices and help text.
    multi_value_filters = (
        (
            "-d",
            "--difficulties",
            ["easy", "normal", "hard", "peaceful", "resource_rich"],
            "Filter by difficulties",
        ),
        (
            "-t",
            "--impetus-types",
            ["general", "focused", "speedrun"],
            "Filter by impetus types",
        ),
    )
    for short_flag, long_flag, allowed, description in multi_value_filters:
        cli.add_argument(short_flag, long_flag, nargs="+", choices=allowed, help=description)

    cli.add_argument(
        "-s", "--split", choices=["train", "val", "test"], help="Filter by dataset split"
    )
    cli.add_argument("--render", action="store_true", help="Render the environment")
    cli.add_argument("--max-steps", type=int, default=1000, help="Maximum steps per episode")

    options = cli.parse_args()

    batch_settings = {
        "dataset_name": options.dataset,
        "num_instances": options.num_instances,
        "difficulties": options.difficulties,
        "impetus_types": options.impetus_types,
        "split": options.split,
        "render": options.render,
        "max_steps": options.max_steps,
    }
    CrafterDatasetRunner().run_batch(**batch_settings)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()