synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +575 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
  100. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +104 -6
  101. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
  102. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
  103. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
  104. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/minigrid/engine.py (added, +589 lines)
@@ -0,0 +1,589 @@
+ """MiniGrid Engine implementation.
+
+ This module provides a wrapper around Gymnasium MiniGrid environments
+ with full state management and serialization capabilities.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, Optional, Tuple, Union, List
+
+ import gymnasium as gym
+ import numpy as np
+ from minigrid.core.constants import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
+
+ from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
+ from synth_ai.environments.reproducibility.core import IReproducibleEngine
+ from synth_ai.environments.environment.rewards.core import RewardComponent, RewardStack
+ from synth_ai.environments.environment.shared_engine import (
+     GetObservationCallable,
+     InternalObservation,
+ )
+ from synth_ai.environments.tasks.core import TaskInstance
+ from synth_ai.environments.examples.minigrid.environment_mapping import (
+     get_environment_from_seed,
+     get_difficulty_from_seed,
+     validate_environment_name,
+ )
+
+
+ @dataclass
+ class MiniGridPublicState:
+     """Public state of the MiniGrid environment."""
+
+     grid_array: np.ndarray  # The grid as a numpy array
+     agent_pos: Tuple[int, int]  # Agent position (x, y)
+     agent_dir: int  # Agent direction (0=right, 1=down, 2=left, 3=up)
+     carrying: Optional[Dict[str, Any]] = None  # Object being carried
+     step_count: int = 0
+     max_steps: int = 1000
+     mission: str = ""
+     terminated: bool = False
+
+     def diff(self, prev_state: "MiniGridPublicState") -> Dict[str, Any]:
+         """Track changes between states."""
+         differences = {}
+         if not np.array_equal(self.grid_array, prev_state.grid_array):
+             differences["grid_array"] = self.grid_array.tolist()
+         if self.agent_pos != prev_state.agent_pos:
+             differences["agent_pos"] = self.agent_pos
+         if self.agent_dir != prev_state.agent_dir:
+             differences["agent_dir"] = self.agent_dir
+         if self.carrying != prev_state.carrying:
+             differences["carrying"] = self.carrying
+         if self.step_count != prev_state.step_count:
+             differences["step_count"] = self.step_count
+         if self.mission != prev_state.mission:
+             differences["mission"] = self.mission
+         if self.terminated != prev_state.terminated:
+             differences["terminated"] = self.terminated
+         return differences
+
+
+ @dataclass
+ class MiniGridPrivateState:
+     """Private state of the MiniGrid environment."""
+
+     reward_last: float = 0.0
+     total_reward: float = 0.0
+     terminated: bool = False
+     truncated: bool = False
+     info: Dict[str, Any] = field(default_factory=dict)
+     # Debug information
+     last_action: Optional[str] = None
+     last_action_result: Optional[str] = (
+         None  # "success", "blocked_by_wall", "blocked_by_boundary", etc.
+     )
+     position_before_action: Optional[Tuple[int, int]] = None
+     position_after_action: Optional[Tuple[int, int]] = None
+     debug_message: Optional[str] = None
+
+     def diff(self, prev_state: "MiniGridPrivateState") -> Dict[str, Any]:
+         """Track changes between states."""
+         differences = {}
+         if self.reward_last != prev_state.reward_last:
+             differences["reward_last"] = self.reward_last
+         if self.total_reward != prev_state.total_reward:
+             differences["total_reward"] = self.total_reward
+         if self.terminated != prev_state.terminated:
+             differences["terminated"] = self.terminated
+         if self.truncated != prev_state.truncated:
+             differences["truncated"] = self.truncated
+         if self.info != prev_state.info:
+             differences["info"] = self.info
+         return differences
+
+
+ @dataclass
+ class MiniGridEngineSnapshot(StatefulEngineSnapshot):
+     """Serialization container for MiniGrid engine."""
+
+     task_instance_dict: Dict
+     engine_snapshot: Dict
+
+
+ class MiniGridGoalReachedComponent(RewardComponent):
+     """Reward component for reaching the goal."""
+
+     def __init__(self, reward_value: float = 1.0):
+         self.reward_value = reward_value
+
+     async def score(self, state: MiniGridPublicState, action: Any) -> float:
+         """Calculate reward based on whether goal was reached."""
+         # Note: We check the private state info for success in the engine
+         return 0.0  # Reward is handled by the base environment
+
+
+ class MiniGridStepPenaltyComponent(RewardComponent):
+     """Penalty for each step taken."""
+
+     def __init__(self, penalty: float = -0.01):
+         self.penalty = penalty
+
+     async def score(self, state: MiniGridPublicState, action: Any) -> float:
+         """Apply small penalty for each step."""
+         return self.penalty
+
+
+ class MiniGridObservationCallable(GetObservationCallable):
+     """Default observation callable for MiniGrid."""
+
+     async def get_observation(
+         self, pub: MiniGridPublicState, priv: MiniGridPrivateState
+     ) -> InternalObservation:
+         """Generate text-based observation of the MiniGrid state."""
+         # Create text representation of the grid
+         grid_lines = []
+         grid_array = pub.grid_array
+         height, width = grid_array.shape[:2]
+
+         # Object type mapping - use actual MiniGrid constants
+         # Note: OBJECT_TO_IDX gives us the correct mapping
+         # We need to create the reverse mapping: idx -> symbol
+
+         # Direction symbols
+         dir_symbols = ["→", "↓", "←", "↑"]
+
+         # Build grid visualization
+         for y in range(height):
+             line = []
+             for x in range(width):
+                 obj_type = grid_array[y, x, 0]
+
+                 if (x, y) == pub.agent_pos:
+                     # Show agent with direction
+                     line.append(dir_symbols[pub.agent_dir])
+                 elif obj_type == OBJECT_TO_IDX["empty"]:  # empty (1)
+                     line.append(".")
+                 elif obj_type == OBJECT_TO_IDX["wall"]:  # wall (2)
+                     line.append("#")
+                 elif obj_type == OBJECT_TO_IDX["goal"]:  # goal (8)
+                     line.append("G")
+                 elif obj_type == OBJECT_TO_IDX["key"]:  # key (5)
+                     line.append("K")
+                 elif obj_type == OBJECT_TO_IDX["door"]:  # door (4)
+                     line.append("D")
+                 elif obj_type == OBJECT_TO_IDX["ball"]:  # ball (6)
+                     line.append("B")
+                 elif obj_type == OBJECT_TO_IDX["lava"]:  # lava (9)
+                     line.append("L")
+                 elif obj_type == OBJECT_TO_IDX["unseen"]:  # unseen (0)
+                     line.append("?")
+                 else:
+                     line.append("?")
+             grid_lines.append(" ".join(line))
+
+         # Build complete observation
+         observation_parts = [
+             f"Mission: {pub.mission}",
+             f"Steps: {pub.step_count}/{pub.max_steps}",
+             f"Agent Position: ({pub.agent_pos[0]}, {pub.agent_pos[1]})",
+             f"Agent Direction: {dir_symbols[pub.agent_dir]}",
+         ]
+
+         if pub.carrying:
+             observation_parts.append(f"Carrying: {pub.carrying['type']} ({pub.carrying['color']})")
+
+         observation_parts.append("\nGrid:")
+         observation_parts.extend(grid_lines)
+
+         observation_parts.append(
+             "\nLegend: # = wall, . = empty, G = goal, K = key, D = door, B = ball, L = lava"
+         )
+         observation_parts.append("Agent directions: → = right, ↓ = down, ← = left, ↑ = up")
+
+         # Add debug information if available - make it more prominent
+         if priv.debug_message or (priv.last_action_result and priv.last_action_result != "success"):
+             observation_parts.append("\n" + "=" * 50)
+             observation_parts.append("🚨 CRITICAL FEEDBACK FROM LAST ACTION:")
+             if priv.debug_message:
+                 observation_parts.append(f" {priv.debug_message}")
+             if priv.last_action_result and priv.last_action_result != "success":
+                 observation_parts.append(f" Result: {priv.last_action_result}")
+                 observation_parts.append(
+                     " ⚠️ IMPORTANT: If blocked, you MUST turn or try different action!"
+                 )
+             observation_parts.append("=" * 50)
+
+         text_obs = "\n".join(observation_parts)
+
+         observation: InternalObservation = {
+             "observation": text_obs,
+             "terminated": pub.terminated,
+             "truncated": priv.truncated,
+             "reward_last": priv.reward_last,
+             "total_reward": priv.total_reward,
+             # Include debug info in observation dict too
+             "last_action": priv.last_action,
+             "last_action_result": priv.last_action_result,
+             "debug_message": priv.debug_message,
+         }
+
+         return observation
+
+
+ class MiniGridCheckpointObservationCallable(GetObservationCallable):
+     """Checkpoint observation callable for MiniGrid."""
+
+     async def get_observation(
+         self, pub: MiniGridPublicState, priv: MiniGridPrivateState
+     ) -> InternalObservation:
+         """Generate checkpoint observation."""
+         observation: InternalObservation = {
+             "mission": pub.mission,
+             "final_position": pub.agent_pos,
+             "total_steps": pub.step_count,
+             "total_reward": priv.total_reward,
+             "terminated": pub.terminated,
+             "truncated": priv.truncated,
+             "success": priv.info.get("success", False),
+         }
+         return observation
+
+
+ class MiniGridEngine(StatefulEngine, IReproducibleEngine):
+     """Engine for MiniGrid environments."""
+
+     def __init__(
+         self,
+         task_instance: TaskInstance,
+         render_mode: Optional[str] = None,
+     ):
+         """Initialize MiniGrid engine.
+
+         Args:
+             task_instance: Task instance containing configuration
+             render_mode: Rendering mode for the environment
+         """
+         self.task_instance = task_instance
+         self.render_mode = render_mode
+
+         # Get environment configuration from task instance
+         env_name = None
+         seed = None
+         difficulty = None
+
+         # First try to get explicit configuration from metadata
+         if hasattr(task_instance, "metadata"):
+             if hasattr(task_instance.metadata, "env_name"):
+                 env_name = task_instance.metadata.env_name
+             if hasattr(task_instance.metadata, "seed"):
+                 seed = task_instance.metadata.seed
+             if hasattr(task_instance.metadata, "difficulty"):
+                 difficulty = task_instance.metadata.difficulty
+
+         # If no explicit env_name but we have a seed, use seed mapping
+         if env_name is None and seed is not None:
+             env_name = get_environment_from_seed(seed)
+             if difficulty is None:
+                 difficulty = get_difficulty_from_seed(seed)
+
+         # If still no environment name, check if we can extract seed from config
+         if env_name is None and hasattr(task_instance, "initial_engine_snapshot"):
+             snapshot = task_instance.initial_engine_snapshot
+             if snapshot and isinstance(snapshot, dict):
+                 config_seed = snapshot.get("seed")
+                 if config_seed is not None:
+                     seed = config_seed
+                     env_name = get_environment_from_seed(seed)
+                     if difficulty is None:
+                         difficulty = get_difficulty_from_seed(seed)
+
+         # Final fallback to default environment
+         if env_name is None:
+             env_name = "MiniGrid-Empty-5x5-v0"
+             seed = 0  # Ensure we have a seed for reproducibility
+
+         # Validate the environment name
+         if not validate_environment_name(env_name):
+             print(f"Warning: Unknown environment '{env_name}', falling back to default")
+             env_name = "MiniGrid-Empty-5x5-v0"
+             seed = 0
+
+         self.env_name = env_name
+         self.seed = seed
+         self.difficulty = difficulty
+
+         # Create the environment
+         self.env = gym.make(self.env_name, render_mode=self.render_mode)
+
+         # Initialize reward stack
+         self.reward_stack = RewardStack(
+             [
+                 MiniGridStepPenaltyComponent(),
+             ]
+         )
+
+         # Initialize state tracking
+         self.total_reward = 0.0
+         self._initialized = False
+
+     def _grid_to_array(self) -> np.ndarray:
+         """Convert MiniGrid grid to numpy array."""
+         # Access the unwrapped environment
+         unwrapped = self.env.unwrapped
+
+         width, height = unwrapped.grid.width, unwrapped.grid.height
+         grid_array = np.zeros((height, width, 3), dtype=np.uint8)
+
+         for i in range(height):
+             for j in range(width):
+                 cell = unwrapped.grid.get(j, i)
+                 if cell is None:
+                     grid_array[i, j] = [OBJECT_TO_IDX["empty"], 0, 0]
+                 else:
+                     grid_array[i, j] = [
+                         OBJECT_TO_IDX.get(cell.type, 0),
+                         COLOR_TO_IDX.get(cell.color, 0),
+                         STATE_TO_IDX.get(getattr(cell, "state", 0), 0)
+                         if hasattr(cell, "state")
+                         else 0,
+                     ]
+
+         # Add agent to grid
+         if unwrapped.agent_pos is not None:
+             ax, ay = unwrapped.agent_pos
+             grid_array[ay, ax] = [
+                 OBJECT_TO_IDX["agent"],
+                 COLOR_TO_IDX["red"],
+                 unwrapped.agent_dir,
+             ]
+
+         return grid_array
+
+     def _extract_public_state(self, terminated: bool = False) -> MiniGridPublicState:
+         """Extract public state from the current environment."""
+         # Access the unwrapped environment
+         unwrapped = self.env.unwrapped
+
+         # Get grid array representation
+         grid_array = self._grid_to_array()
+
+         # Get carrying object info
+         carrying = None
+         if unwrapped.carrying:
+             carrying = {
+                 "type": unwrapped.carrying.type,
+                 "color": unwrapped.carrying.color,
+             }
+
+         return MiniGridPublicState(
+             grid_array=grid_array,
+             agent_pos=tuple(unwrapped.agent_pos),
+             agent_dir=unwrapped.agent_dir,
+             carrying=carrying,
+             step_count=unwrapped.step_count,
+             max_steps=unwrapped.max_steps,
+             mission=unwrapped.mission,
+             terminated=terminated,
+         )
+
+     async def _reset_engine(
+         self, *, seed: int | None = None
+     ) -> Tuple[MiniGridPrivateState, MiniGridPublicState]:
+         """Reset to initial state."""
+         # Reset environment
+         if seed is not None:
+             obs, info = self.env.reset(seed=seed)
+         elif self.seed is not None:
+             obs, info = self.env.reset(seed=self.seed)
+         else:
+             obs, info = self.env.reset()
+
+         # Reset tracking
+         self.total_reward = 0.0
+         self._initialized = True
+
+         # Create states
+         public_state = self._extract_public_state(terminated=False)
+         private_state = MiniGridPrivateState(
+             reward_last=0.0,
+             total_reward=0.0,
+             terminated=False,
+             truncated=False,
+             info=info,
+         )
+
+         return private_state, public_state
+
+     async def _step_engine(self, action: int) -> Tuple[MiniGridPrivateState, MiniGridPublicState]:
+         """Execute one step/action."""
+         if not self._initialized:
+             raise RuntimeError("Engine not initialized. Call _reset_engine first.")
+
+         # Validate action
+         if not isinstance(action, int) or action < 0 or action > 6:
+             raise ValueError(f"Invalid action: {action}. Must be integer 0-6.")
+
+         # Get position before action
+         unwrapped = self.env.unwrapped
+         pos_before = unwrapped.agent_pos
+         dir_before = unwrapped.agent_dir
+
+         # Map action to name
+         action_names = {
+             0: "left",
+             1: "right",
+             2: "forward",
+             3: "pickup",
+             4: "drop",
+             5: "toggle",
+             6: "done",
+         }
+         action_name = action_names.get(action, f"unknown({action})")
+
+         # Execute action in environment
+         obs, reward, terminated, truncated, info = self.env.step(action)
+
+         # Get position after action
+         pos_after = unwrapped.agent_pos
+         dir_after = unwrapped.agent_dir
+
+         # Determine action result
+         action_result = "success"
+         debug_message = f"Action: {action_name}"
+
+         if action in [0, 1]:  # Turn actions
+             if dir_before != dir_after:
+                 action_result = "turned"
+                 debug_message = f"Turned {action_name}: direction {dir_before} -> {dir_after}"
+             else:
+                 action_result = "turn_failed"
+                 debug_message = f"Turn {action_name} failed"
+         elif action == 2:  # Forward action
+             if pos_before == pos_after:
+                 # Check what blocked movement
+                 fwd_pos = unwrapped.front_pos
+                 if (
+                     fwd_pos[0] < 0
+                     or fwd_pos[0] >= unwrapped.width
+                     or fwd_pos[1] < 0
+                     or fwd_pos[1] >= unwrapped.height
+                 ):
+                     action_result = "blocked_by_boundary"
+                     debug_message = f"Forward blocked by grid boundary at {fwd_pos}"
+                 else:
+                     cell = unwrapped.grid.get(*fwd_pos)
+                     if cell is not None and cell.type == "wall":
+                         action_result = "blocked_by_wall"
+                         debug_message = f"Forward blocked by wall at {fwd_pos}"
+                     elif cell is not None and cell.type == "lava":
+                         action_result = "blocked_by_lava"
+                         debug_message = f"Forward blocked by lava at {fwd_pos}"
+                     else:
+                         action_result = "blocked_by_object"
+                         debug_message = (
+                             f"Forward blocked by {cell.type if cell else 'unknown'} at {fwd_pos}"
+                         )
+             else:
+                 action_result = "moved"
+                 debug_message = f"Moved forward: {pos_before} -> {pos_after}"
+
+         # Calculate custom rewards
+         public_state = self._extract_public_state(terminated=terminated)
+         custom_reward = await self.reward_stack.step_reward(public_state, action)
+
+         # Use environment reward as base, add custom rewards
+         total_step_reward = reward + custom_reward
+         self.total_reward += total_step_reward
+
+         # Create states with debug info
+         private_state = MiniGridPrivateState(
+             reward_last=total_step_reward,
+             total_reward=self.total_reward,
+             terminated=terminated,
+             truncated=truncated,
+             info=info,
+             last_action=action_name,
+             last_action_result=action_result,
+             position_before_action=pos_before,
+             position_after_action=pos_after,
+             debug_message=debug_message,
+         )
+
+         return private_state, public_state
+
+     async def _serialize_engine(self) -> MiniGridEngineSnapshot:
+         """Serialize current state."""
+         engine_snapshot = {
+             "env_name": self.env_name,
+             "seed": self.seed,
+             "total_reward": self.total_reward,
+             "initialized": self._initialized,
+             # Note: Full environment state serialization would require
+             # MiniGrid to support it, which it doesn't by default
+         }
+
+         task_dict = {}
+         if hasattr(self.task_instance, "serialize"):
+             task_dict = await self.task_instance.serialize()
+
+         return MiniGridEngineSnapshot(
+             task_instance_dict=task_dict,
+             engine_snapshot=engine_snapshot,
+         )
+
+     @classmethod
+     async def _deserialize_engine(cls, snapshot: MiniGridEngineSnapshot) -> "MiniGridEngine":
+         """Restore from serialized state."""
+         # Recreate task instance
+         task_instance = None
+         if snapshot.task_instance_dict:
+             # This would need proper task instance deserialization
+             task_instance = TaskInstance(**snapshot.task_instance_dict)
+
+         # Create engine
+         engine = cls(task_instance)
+
+         # Restore state
+         engine_data = snapshot.engine_snapshot
+         engine.total_reward = engine_data.get("total_reward", 0.0)
+         engine._initialized = engine_data.get("initialized", False)
+
+         return engine
+
+     def get_current_states_for_observation(
+         self,
+     ) -> Tuple[MiniGridPrivateState, MiniGridPublicState]:
+         """Get current states without advancing."""
+         if not self._initialized:
+             # Return empty states
+             return (
+                 MiniGridPrivateState(),
+                 MiniGridPublicState(
+                     grid_array=np.zeros((5, 5, 3)),
+                     agent_pos=(0, 0),
+                     agent_dir=0,
+                 ),
+             )
+
+         # Access the unwrapped environment
+         unwrapped = self.env.unwrapped
+
+         # Get current state
+         terminated = unwrapped.step_count >= unwrapped.max_steps
+         public_state = self._extract_public_state(terminated=terminated)
+
+         private_state = MiniGridPrivateState(
+             reward_last=0.0,
+             total_reward=self.total_reward,
+             terminated=terminated,
+             truncated=terminated,
+             info={},
+         )
+
+         return private_state, public_state
+
+     def get_available_actions(self) -> List[Tuple[int, str]]:
+         """Get list of available actions with descriptions."""
+         return [
+             (0, "turn left"),  # Action 0 is counter-clockwise (left)
+             (1, "turn right"),  # Action 1 is clockwise (right)
+             (2, "move forward"),
+             (3, "pickup"),
+             (4, "drop"),
+             (5, "toggle/activate"),
+             (6, "done (complete mission)"),
+         ]
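
For orientation, here is a minimal sketch of how this new engine can be driven directly. It assumes a task_instance has already been constructed elsewhere (its exact shape is defined in minigrid/taskset.py, not shown here), and it calls the engine's internal async methods for illustration; in normal use the MiniGridEnvironment wrapper added in minigrid/environment.py is the entry point. The demo() helper and my_task_instance name are hypothetical, not part of the package.

import asyncio

from synth_ai.environments.examples.minigrid.engine import (
    MiniGridEngine,
    MiniGridObservationCallable,
)

async def demo(task_instance) -> None:
    # The engine resolves env_name/seed/difficulty from task_instance metadata,
    # falling back to "MiniGrid-Empty-5x5-v0" with seed 0 (see __init__ above).
    engine = MiniGridEngine(task_instance)
    priv, pub = await engine._reset_engine()

    # Render the text observation (mission, step counter, ASCII grid, legend).
    obs = await MiniGridObservationCallable().get_observation(pub, priv)
    print(obs["observation"])

    # Actions are integers 0-6 (see get_available_actions); 2 = move forward.
    priv, pub = await engine._step_engine(2)
    print(priv.last_action_result, priv.debug_message)

# asyncio.run(demo(my_task_instance))  # my_task_instance: hypothetical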