synth-ai 0.2.4.dev3__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/tracing_v3/examples/basic_usage.py +188 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +105 -6
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/minigrid/engine.py
@@ -0,0 +1,589 @@
+"""MiniGrid Engine implementation.
+
+This module provides a wrapper around Gymnasium MiniGrid environments
+with full state management and serialization capabilities.
+"""
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional, Tuple, Union, List
+
+import gymnasium as gym
+import numpy as np
+from minigrid.core.constants import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX
+
+from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
+from synth_ai.environments.reproducibility.core import IReproducibleEngine
+from synth_ai.environments.environment.rewards.core import RewardComponent, RewardStack
+from synth_ai.environments.environment.shared_engine import (
+    GetObservationCallable,
+    InternalObservation,
+)
+from synth_ai.environments.tasks.core import TaskInstance
+from synth_ai.environments.examples.minigrid.environment_mapping import (
+    get_environment_from_seed,
+    get_difficulty_from_seed,
+    validate_environment_name,
+)
+
+
+@dataclass
+class MiniGridPublicState:
+    """Public state of the MiniGrid environment."""
+
+    grid_array: np.ndarray  # The grid as a numpy array
+    agent_pos: Tuple[int, int]  # Agent position (x, y)
+    agent_dir: int  # Agent direction (0=right, 1=down, 2=left, 3=up)
+    carrying: Optional[Dict[str, Any]] = None  # Object being carried
+    step_count: int = 0
+    max_steps: int = 1000
+    mission: str = ""
+    terminated: bool = False
+
+    def diff(self, prev_state: "MiniGridPublicState") -> Dict[str, Any]:
+        """Track changes between states."""
+        differences = {}
+        if not np.array_equal(self.grid_array, prev_state.grid_array):
+            differences["grid_array"] = self.grid_array.tolist()
+        if self.agent_pos != prev_state.agent_pos:
+            differences["agent_pos"] = self.agent_pos
+        if self.agent_dir != prev_state.agent_dir:
+            differences["agent_dir"] = self.agent_dir
+        if self.carrying != prev_state.carrying:
+            differences["carrying"] = self.carrying
+        if self.step_count != prev_state.step_count:
+            differences["step_count"] = self.step_count
+        if self.mission != prev_state.mission:
+            differences["mission"] = self.mission
+        if self.terminated != prev_state.terminated:
+            differences["terminated"] = self.terminated
+        return differences
+
+
+@dataclass
+class MiniGridPrivateState:
+    """Private state of the MiniGrid environment."""
+
+    reward_last: float = 0.0
+    total_reward: float = 0.0
+    terminated: bool = False
+    truncated: bool = False
+    info: Dict[str, Any] = field(default_factory=dict)
+    # Debug information
+    last_action: Optional[str] = None
+    last_action_result: Optional[str] = (
+        None  # "success", "blocked_by_wall", "blocked_by_boundary", etc.
+    )
+    position_before_action: Optional[Tuple[int, int]] = None
+    position_after_action: Optional[Tuple[int, int]] = None
+    debug_message: Optional[str] = None
+
+    def diff(self, prev_state: "MiniGridPrivateState") -> Dict[str, Any]:
+        """Track changes between states."""
+        differences = {}
+        if self.reward_last != prev_state.reward_last:
+            differences["reward_last"] = self.reward_last
+        if self.total_reward != prev_state.total_reward:
+            differences["total_reward"] = self.total_reward
+        if self.terminated != prev_state.terminated:
+            differences["terminated"] = self.terminated
+        if self.truncated != prev_state.truncated:
+            differences["truncated"] = self.truncated
+        if self.info != prev_state.info:
+            differences["info"] = self.info
+        return differences
+
+
+@dataclass
+class MiniGridEngineSnapshot(StatefulEngineSnapshot):
+    """Serialization container for MiniGrid engine."""
+
+    task_instance_dict: Dict
+    engine_snapshot: Dict
+
+
+class MiniGridGoalReachedComponent(RewardComponent):
+    """Reward component for reaching the goal."""
+
+    def __init__(self, reward_value: float = 1.0):
+        self.reward_value = reward_value
+
+    async def score(self, state: MiniGridPublicState, action: Any) -> float:
+        """Calculate reward based on whether goal was reached."""
+        # Note: We check the private state info for success in the engine
+        return 0.0  # Reward is handled by the base environment
+
+
+class MiniGridStepPenaltyComponent(RewardComponent):
+    """Penalty for each step taken."""
+
+    def __init__(self, penalty: float = -0.01):
+        self.penalty = penalty
+
+    async def score(self, state: MiniGridPublicState, action: Any) -> float:
+        """Apply small penalty for each step."""
+        return self.penalty
+
+
+class MiniGridObservationCallable(GetObservationCallable):
+    """Default observation callable for MiniGrid."""
+
+    async def get_observation(
+        self, pub: MiniGridPublicState, priv: MiniGridPrivateState
+    ) -> InternalObservation:
+        """Generate text-based observation of the MiniGrid state."""
+        # Create text representation of the grid
+        grid_lines = []
+        grid_array = pub.grid_array
+        height, width = grid_array.shape[:2]
+
+        # Object type mapping - use actual MiniGrid constants
+        # Note: OBJECT_TO_IDX gives us the correct mapping
+        # We need to create the reverse mapping: idx -> symbol
+
+        # Direction symbols
+        dir_symbols = ["→", "↓", "←", "↑"]
+
+        # Build grid visualization
+        for y in range(height):
+            line = []
+            for x in range(width):
+                obj_type = grid_array[y, x, 0]
+
+                if (x, y) == pub.agent_pos:
+                    # Show agent with direction
+                    line.append(dir_symbols[pub.agent_dir])
+                elif obj_type == OBJECT_TO_IDX["empty"]:  # empty (1)
+                    line.append(".")
+                elif obj_type == OBJECT_TO_IDX["wall"]:  # wall (2)
+                    line.append("#")
+                elif obj_type == OBJECT_TO_IDX["goal"]:  # goal (8)
+                    line.append("G")
+                elif obj_type == OBJECT_TO_IDX["key"]:  # key (5)
+                    line.append("K")
+                elif obj_type == OBJECT_TO_IDX["door"]:  # door (4)
+                    line.append("D")
+                elif obj_type == OBJECT_TO_IDX["ball"]:  # ball (6)
+                    line.append("B")
+                elif obj_type == OBJECT_TO_IDX["lava"]:  # lava (9)
+                    line.append("L")
+                elif obj_type == OBJECT_TO_IDX["unseen"]:  # unseen (0)
+                    line.append("?")
+                else:
+                    line.append("?")
+            grid_lines.append(" ".join(line))
+
+        # Build complete observation
+        observation_parts = [
+            f"Mission: {pub.mission}",
+            f"Steps: {pub.step_count}/{pub.max_steps}",
+            f"Agent Position: ({pub.agent_pos[0]}, {pub.agent_pos[1]})",
+            f"Agent Direction: {dir_symbols[pub.agent_dir]}",
+        ]
+
+        if pub.carrying:
+            observation_parts.append(f"Carrying: {pub.carrying['type']} ({pub.carrying['color']})")
+
+        observation_parts.append("\nGrid:")
+        observation_parts.extend(grid_lines)
+
+        observation_parts.append(
+            "\nLegend: # = wall, . = empty, G = goal, K = key, D = door, B = ball, L = lava"
+        )
+        observation_parts.append("Agent directions: → = right, ↓ = down, ← = left, ↑ = up")
+
+        # Add debug information if available - make it more prominent
+        if priv.debug_message or (priv.last_action_result and priv.last_action_result != "success"):
+            observation_parts.append("\n" + "=" * 50)
+            observation_parts.append("🚨 CRITICAL FEEDBACK FROM LAST ACTION:")
+            if priv.debug_message:
+                observation_parts.append(f"   {priv.debug_message}")
+            if priv.last_action_result and priv.last_action_result != "success":
+                observation_parts.append(f"   Result: {priv.last_action_result}")
+                observation_parts.append(
+                    "   ⚠️ IMPORTANT: If blocked, you MUST turn or try different action!"
+                )
+            observation_parts.append("=" * 50)
+
+        text_obs = "\n".join(observation_parts)
+
+        observation: InternalObservation = {
+            "observation": text_obs,
+            "terminated": pub.terminated,
+            "truncated": priv.truncated,
+            "reward_last": priv.reward_last,
+            "total_reward": priv.total_reward,
+            # Include debug info in observation dict too
+            "last_action": priv.last_action,
+            "last_action_result": priv.last_action_result,
+            "debug_message": priv.debug_message,
+        }
+
+        return observation
+
+
+class MiniGridCheckpointObservationCallable(GetObservationCallable):
+    """Checkpoint observation callable for MiniGrid."""
+
+    async def get_observation(
+        self, pub: MiniGridPublicState, priv: MiniGridPrivateState
+    ) -> InternalObservation:
+        """Generate checkpoint observation."""
+        observation: InternalObservation = {
+            "mission": pub.mission,
+            "final_position": pub.agent_pos,
+            "total_steps": pub.step_count,
+            "total_reward": priv.total_reward,
+            "terminated": pub.terminated,
+            "truncated": priv.truncated,
+            "success": priv.info.get("success", False),
+        }
+        return observation
+
+
+class MiniGridEngine(StatefulEngine, IReproducibleEngine):
+    """Engine for MiniGrid environments."""
+
+    def __init__(
+        self,
+        task_instance: TaskInstance,
+        render_mode: Optional[str] = None,
+    ):
+        """Initialize MiniGrid engine.
+
+        Args:
+            task_instance: Task instance containing configuration
+            render_mode: Rendering mode for the environment
+        """
+        self.task_instance = task_instance
+        self.render_mode = render_mode
+
+        # Get environment configuration from task instance
+        env_name = None
+        seed = None
+        difficulty = None
+
+        # First try to get explicit configuration from metadata
+        if hasattr(task_instance, "metadata"):
+            if hasattr(task_instance.metadata, "env_name"):
+                env_name = task_instance.metadata.env_name
+            if hasattr(task_instance.metadata, "seed"):
+                seed = task_instance.metadata.seed
+            if hasattr(task_instance.metadata, "difficulty"):
+                difficulty = task_instance.metadata.difficulty
+
+        # If no explicit env_name but we have a seed, use seed mapping
+        if env_name is None and seed is not None:
+            env_name = get_environment_from_seed(seed)
+            if difficulty is None:
+                difficulty = get_difficulty_from_seed(seed)
+
+        # If still no environment name, check if we can extract seed from config
+        if env_name is None and hasattr(task_instance, "initial_engine_snapshot"):
+            snapshot = task_instance.initial_engine_snapshot
+            if snapshot and isinstance(snapshot, dict):
+                config_seed = snapshot.get("seed")
+                if config_seed is not None:
+                    seed = config_seed
+                    env_name = get_environment_from_seed(seed)
+                    if difficulty is None:
+                        difficulty = get_difficulty_from_seed(seed)
+
+        # Final fallback to default environment
+        if env_name is None:
+            env_name = "MiniGrid-Empty-5x5-v0"
+            seed = 0  # Ensure we have a seed for reproducibility
+
+        # Validate the environment name
+        if not validate_environment_name(env_name):
+            print(f"Warning: Unknown environment '{env_name}', falling back to default")
+            env_name = "MiniGrid-Empty-5x5-v0"
+            seed = 0
+
+        self.env_name = env_name
+        self.seed = seed
+        self.difficulty = difficulty
+
+        # Create the environment
+        self.env = gym.make(self.env_name, render_mode=self.render_mode)
+
+        # Initialize reward stack
+        self.reward_stack = RewardStack(
+            [
+                MiniGridStepPenaltyComponent(),
+            ]
+        )
+
+        # Initialize state tracking
+        self.total_reward = 0.0
+        self._initialized = False
+
+    def _grid_to_array(self) -> np.ndarray:
+        """Convert MiniGrid grid to numpy array."""
+        # Access the unwrapped environment
+        unwrapped = self.env.unwrapped
+
+        width, height = unwrapped.grid.width, unwrapped.grid.height
+        grid_array = np.zeros((height, width, 3), dtype=np.uint8)
+
+        for i in range(height):
+            for j in range(width):
+                cell = unwrapped.grid.get(j, i)
+                if cell is None:
+                    grid_array[i, j] = [OBJECT_TO_IDX["empty"], 0, 0]
+                else:
+                    grid_array[i, j] = [
+                        OBJECT_TO_IDX.get(cell.type, 0),
+                        COLOR_TO_IDX.get(cell.color, 0),
+                        STATE_TO_IDX.get(getattr(cell, "state", 0), 0)
+                        if hasattr(cell, "state")
+                        else 0,
+                    ]
+
+        # Add agent to grid
+        if unwrapped.agent_pos is not None:
+            ax, ay = unwrapped.agent_pos
+            grid_array[ay, ax] = [
+                OBJECT_TO_IDX["agent"],
+                COLOR_TO_IDX["red"],
+                unwrapped.agent_dir,
+            ]
+
+        return grid_array
+
+    def _extract_public_state(self, terminated: bool = False) -> MiniGridPublicState:
+        """Extract public state from the current environment."""
+        # Access the unwrapped environment
+        unwrapped = self.env.unwrapped
+
+        # Get grid array representation
+        grid_array = self._grid_to_array()
+
+        # Get carrying object info
+        carrying = None
+        if unwrapped.carrying:
+            carrying = {
+                "type": unwrapped.carrying.type,
+                "color": unwrapped.carrying.color,
+            }
+
+        return MiniGridPublicState(
+            grid_array=grid_array,
+            agent_pos=tuple(unwrapped.agent_pos),
+            agent_dir=unwrapped.agent_dir,
+            carrying=carrying,
+            step_count=unwrapped.step_count,
+            max_steps=unwrapped.max_steps,
+            mission=unwrapped.mission,
+            terminated=terminated,
+        )
+
+    async def _reset_engine(
+        self, *, seed: int | None = None
+    ) -> Tuple[MiniGridPrivateState, MiniGridPublicState]:
+        """Reset to initial state."""
+        # Reset environment
+        if seed is not None:
+            obs, info = self.env.reset(seed=seed)
+        elif self.seed is not None:
+            obs, info = self.env.reset(seed=self.seed)
+        else:
+            obs, info = self.env.reset()
+
+        # Reset tracking
+        self.total_reward = 0.0
+        self._initialized = True
+
+        # Create states
+        public_state = self._extract_public_state(terminated=False)
+        private_state = MiniGridPrivateState(
+            reward_last=0.0,
+            total_reward=0.0,
+            terminated=False,
+            truncated=False,
+            info=info,
+        )
+
+        return private_state, public_state
+
+    async def _step_engine(self, action: int) -> Tuple[MiniGridPrivateState, MiniGridPublicState]:
+        """Execute one step/action."""
+        if not self._initialized:
+            raise RuntimeError("Engine not initialized. Call _reset_engine first.")
+
+        # Validate action
+        if not isinstance(action, int) or action < 0 or action > 6:
+            raise ValueError(f"Invalid action: {action}. Must be integer 0-6.")
+
+        # Get position before action
+        unwrapped = self.env.unwrapped
+        pos_before = unwrapped.agent_pos
+        dir_before = unwrapped.agent_dir
+
+        # Map action to name
+        action_names = {
+            0: "left",
+            1: "right",
+            2: "forward",
+            3: "pickup",
+            4: "drop",
+            5: "toggle",
+            6: "done",
+        }
+        action_name = action_names.get(action, f"unknown({action})")
+
+        # Execute action in environment
+        obs, reward, terminated, truncated, info = self.env.step(action)
+
+        # Get position after action
+        pos_after = unwrapped.agent_pos
+        dir_after = unwrapped.agent_dir
+
+        # Determine action result
+        action_result = "success"
+        debug_message = f"Action: {action_name}"
+
+        if action in [0, 1]:  # Turn actions
+            if dir_before != dir_after:
+                action_result = "turned"
+                debug_message = f"Turned {action_name}: direction {dir_before} -> {dir_after}"
+            else:
+                action_result = "turn_failed"
+                debug_message = f"Turn {action_name} failed"
+        elif action == 2:  # Forward action
+            if pos_before == pos_after:
+                # Check what blocked movement
+                fwd_pos = unwrapped.front_pos
+                if (
+                    fwd_pos[0] < 0
+                    or fwd_pos[0] >= unwrapped.width
+                    or fwd_pos[1] < 0
+                    or fwd_pos[1] >= unwrapped.height
+                ):
+                    action_result = "blocked_by_boundary"
+                    debug_message = f"Forward blocked by grid boundary at {fwd_pos}"
+                else:
+                    cell = unwrapped.grid.get(*fwd_pos)
+                    if cell is not None and cell.type == "wall":
+                        action_result = "blocked_by_wall"
+                        debug_message = f"Forward blocked by wall at {fwd_pos}"
+                    elif cell is not None and cell.type == "lava":
+                        action_result = "blocked_by_lava"
+                        debug_message = f"Forward blocked by lava at {fwd_pos}"
+                    else:
+                        action_result = "blocked_by_object"
+                        debug_message = (
+                            f"Forward blocked by {cell.type if cell else 'unknown'} at {fwd_pos}"
+                        )
+            else:
+                action_result = "moved"
+                debug_message = f"Moved forward: {pos_before} -> {pos_after}"
+
+        # Calculate custom rewards
+        public_state = self._extract_public_state(terminated=terminated)
+        custom_reward = await self.reward_stack.step_reward(public_state, action)
+
+        # Use environment reward as base, add custom rewards
+        total_step_reward = reward + custom_reward
+        self.total_reward += total_step_reward
+
+        # Create states with debug info
+        private_state = MiniGridPrivateState(
+            reward_last=total_step_reward,
+            total_reward=self.total_reward,
+            terminated=terminated,
+            truncated=truncated,
+            info=info,
+            last_action=action_name,
+            last_action_result=action_result,
+            position_before_action=pos_before,
+            position_after_action=pos_after,
+            debug_message=debug_message,
+        )
+
+        return private_state, public_state
+
+    async def _serialize_engine(self) -> MiniGridEngineSnapshot:
+        """Serialize current state."""
+        engine_snapshot = {
+            "env_name": self.env_name,
+            "seed": self.seed,
+            "total_reward": self.total_reward,
+            "initialized": self._initialized,
+            # Note: Full environment state serialization would require
+            # MiniGrid to support it, which it doesn't by default
+        }
+
+        task_dict = {}
+        if hasattr(self.task_instance, "serialize"):
+            task_dict = await self.task_instance.serialize()
+
+        return MiniGridEngineSnapshot(
+            task_instance_dict=task_dict,
+            engine_snapshot=engine_snapshot,
+        )
+
+    @classmethod
+    async def _deserialize_engine(cls, snapshot: MiniGridEngineSnapshot) -> "MiniGridEngine":
+        """Restore from serialized state."""
+        # Recreate task instance
+        task_instance = None
+        if snapshot.task_instance_dict:
+            # This would need proper task instance deserialization
+            task_instance = TaskInstance(**snapshot.task_instance_dict)
+
+        # Create engine
+        engine = cls(task_instance)
+
+        # Restore state
+        engine_data = snapshot.engine_snapshot
+        engine.total_reward = engine_data.get("total_reward", 0.0)
+        engine._initialized = engine_data.get("initialized", False)
+
+        return engine
+
+    def get_current_states_for_observation(
+        self,
+    ) -> Tuple[MiniGridPrivateState, MiniGridPublicState]:
+        """Get current states without advancing."""
+        if not self._initialized:
+            # Return empty states
+            return (
+                MiniGridPrivateState(),
+                MiniGridPublicState(
+                    grid_array=np.zeros((5, 5, 3)),
+                    agent_pos=(0, 0),
+                    agent_dir=0,
+                ),
+            )
+
+        # Access the unwrapped environment
+        unwrapped = self.env.unwrapped
+
+        # Get current state
+        terminated = unwrapped.step_count >= unwrapped.max_steps
+        public_state = self._extract_public_state(terminated=terminated)
+
+        private_state = MiniGridPrivateState(
+            reward_last=0.0,
+            total_reward=self.total_reward,
+            terminated=terminated,
+            truncated=terminated,
+            info={},
+        )
+
+        return private_state, public_state
+
+    def get_available_actions(self) -> List[Tuple[int, str]]:
+        """Get list of available actions with descriptions."""
+        return [
+            (0, "turn left"),  # Action 0 is counter-clockwise (left)
+            (1, "turn right"),  # Action 1 is clockwise (right)
+            (2, "move forward"),
+            (3, "pickup"),
+            (4, "drop"),
+            (5, "toggle/activate"),
+            (6, "done (complete mission)"),
+        ]
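For context, here is a minimal sketch of how the engine added in this hunk can be driven. It is illustrative only: the `SimpleNamespace` task stub is a hypothetical stand-in for the real `TaskInstance` from `synth_ai.environments.tasks.core`, and it exercises the engine's internal `_reset_engine`/`_step_engine` hooks directly, whereas in practice the `MiniGridEnvironment` wrapper (environment.py in this same diff) would mediate those calls.

# Hypothetical driver for the new MiniGrid engine (not part of the package).
# Assumes a task object exposing .metadata.env_name / .metadata.seed, which is
# what MiniGridEngine.__init__ probes for; adjust to the real TaskInstance API.
import asyncio
from types import SimpleNamespace

from synth_ai.environments.examples.minigrid.engine import (
    MiniGridEngine,
    MiniGridObservationCallable,
)

async def main() -> None:
    task = SimpleNamespace(
        metadata=SimpleNamespace(env_name="MiniGrid-Empty-5x5-v0", seed=0)
    )
    engine = MiniGridEngine(task)

    priv, pub = await engine._reset_engine()
    obs_callable = MiniGridObservationCallable()

    # Turn right, then attempt to move forward twice; the private state's
    # debug fields report whether each move succeeded or was blocked.
    for action in (1, 2, 2):
        priv, pub = await engine._step_engine(action)
        obs = await obs_callable.get_observation(pub, priv)
        print(obs["observation"])
        if priv.terminated or priv.truncated:
            break

asyncio.run(main())

Note the split the engine enforces: everything an agent may see (grid, position, mission) lives in `MiniGridPublicState`, while rewards, truncation, and action-outcome diagnostics live in `MiniGridPrivateState`, and the observation callable decides how much of each to surface.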