synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +579 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +391 -0
- synth_ai/environments/examples/wordle/environment.py +154 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +222 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/environments/service/core_routes.py +38 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
- synth_ai/learning/prompts/mipro.py +273 -1
- synth_ai/learning/prompts/random_search.py +247 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
- synth_ai/lm/injection.py +81 -0
- synth_ai/lm/overrides.py +204 -0
- synth_ai/lm/provider_support/anthropic.py +39 -12
- synth_ai/lm/provider_support/openai.py +31 -4
- synth_ai/lm/vendors/core/anthropic_api.py +16 -0
- synth_ai/lm/vendors/openai_standard.py +35 -5
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,274 @@
|
|
1
|
+
"""MiniGrid Environment implementation.
|
2
|
+
|
3
|
+
This module provides a high-level interface for MiniGrid environments
|
4
|
+
with tool-based interaction and flexible observation generation.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from __future__ import annotations
|
8
|
+
|
9
|
+
import json
|
10
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
from pydantic import BaseModel, Field
|
14
|
+
|
15
|
+
from synth_ai.environments.environment.tools import AbstractTool, EnvToolCall, ToolResult
|
16
|
+
from synth_ai.environments.stateful.core import StatefulEnvironment
|
17
|
+
from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
|
18
|
+
from synth_ai.environments.environment.shared_engine import (
|
19
|
+
GetObservationCallable,
|
20
|
+
InternalObservation,
|
21
|
+
)
|
22
|
+
from synth_ai.environments.examples.minigrid.engine import (
|
23
|
+
MiniGridEngine,
|
24
|
+
MiniGridPublicState,
|
25
|
+
MiniGridPrivateState,
|
26
|
+
MiniGridObservationCallable,
|
27
|
+
MiniGridCheckpointObservationCallable,
|
28
|
+
)
|
29
|
+
from synth_ai.environments.tasks.core import TaskInstance
|
30
|
+
|
31
|
+
|
32
|
+
class MiniGridActionInput(BaseModel):
    """Argument schema for a MiniGrid action tool call.

    Carries a single required field, ``action``, naming the move to perform.
    """

    # Required field (``...`` means "no default"); the description lists the
    # action names that are accepted.
    action: str = Field(
        default=...,
        description="The action to take. Must be one of: 'left', 'right', 'forward', 'pickup', 'drop', 'toggle', 'done'",
    )
|
39
|
+
|
40
|
+
|
41
|
+
class MiniGridInteractTool(AbstractTool):
    """Tool that applies one named action to a MiniGrid engine.

    The tool translates a human-readable action name into the integer action
    id passed to ``engine._step_engine`` and reports the outcome as a
    ``ToolResult``.
    """

    name = "minigrid_act"
    description = "Perform an action in the MiniGrid environment"
    call_schema = MiniGridActionInput
    result_schema = ToolResult

    def __init__(self, engine: MiniGridEngine):
        """Bind the tool to ``engine`` and build the name -> action-id table."""
        self.engine = engine
        self.action_map = {
            "left": 0,  # Action 0 is counter-clockwise (left)
            "right": 1,  # Action 1 is clockwise (right)
            "forward": 2,
            "pickup": 3,
            "drop": 4,
            "toggle": 5,
            "done": 6,
        }

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Execute the requested action and return the result.

        Args:
            call: Tool call whose ``args`` mapping holds the ``"action"`` name
                (matched case-insensitively against ``self.action_map``).

        Returns:
            ``ToolResult`` with ``ok=True`` and a payload containing a
            ``"message"`` summary plus the ``"public_state"`` and
            ``"private_state"`` returned by the engine; or ``ok=False`` with
            an ``error`` string and an empty payload when the action is
            invalid or the engine step raises.
        """
        try:
            raw_action = call.args.get("action", "")

            # Only strings can name an action. Checking the type first means a
            # None / int / list value yields the same helpful "valid actions"
            # message as an unknown name, instead of an AttributeError text.
            action_name = raw_action.lower() if isinstance(raw_action, str) else raw_action
            if not isinstance(raw_action, str) or action_name not in self.action_map:
                return ToolResult(
                    ok=False,
                    error=f"Invalid action '{action_name}'. Valid actions are: {', '.join(self.action_map.keys())}",
                    payload={},
                )

            action = self.action_map[action_name]

            # Execute the action
            private_state, public_state = await self.engine._step_engine(action)

            # Build response
            response_parts = [f"Action '{action_name}' executed."]

            if private_state.reward_last != 0:
                response_parts.append(f"Reward: {private_state.reward_last:.2f}")

            # Termination takes precedence over truncation in the message.
            if private_state.terminated:
                response_parts.append("Episode terminated!")
                if private_state.info.get("success", False):
                    response_parts.append("Mission completed successfully!")
            elif private_state.truncated:
                response_parts.append("Episode truncated (max steps reached).")

            return ToolResult(
                ok=True,
                payload={
                    "message": " ".join(response_parts),
                    "public_state": public_state,
                    "private_state": private_state,
                },
            )
        except Exception as e:
            # Deliberate catch-all: any failure is reported to the caller as a
            # failed ToolResult rather than raised.
            return ToolResult(ok=False, error=str(e), payload={})
|
102
|
+
|
103
|
+
|
104
|
+
class MiniGridEnvironment(StatefulEnvironment, ReproducibleEnvironment[MiniGridEngine]):
    """High-level MiniGrid environment with tool-based interaction.

    Wraps a ``MiniGridEngine`` and drives it through a single tool,
    ``minigrid_act``. Observations are produced by one of two callables: one
    used for ``initialize``/``step`` and one used for
    ``checkpoint``/``terminate``.
    """

    def __init__(
        self,
        task_instance: TaskInstance,
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
    ):
        """Initialize the MiniGrid environment.

        Args:
            task_instance: Task instance containing configuration
            custom_step_obs: Custom observation generator for steps
                (defaults to ``MiniGridObservationCallable()``)
            custom_ckpt_obs: Custom observation generator for checkpoints
                (defaults to ``MiniGridCheckpointObservationCallable()``)
        """
        self.name = "MiniGridEnvironment"
        self.task_instance = task_instance
        self.custom_step_observation_callable = custom_step_obs or MiniGridObservationCallable()
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or MiniGridCheckpointObservationCallable()
        )

        # Create engine
        self.engine = MiniGridEngine(task_instance)

        # Initialize tool
        self._interact_tool = MiniGridInteractTool(self.engine)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the initial step observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def step(
        self,
        tool_calls: Union[List[Dict[str, Any]], List[EnvToolCall], Dict[str, Any], EnvToolCall],
    ) -> InternalObservation:
        """Process a tool call and return the resulting observation.

        On a failed tool result the current engine states are observed and
        the tool's error string is attached under the ``"error"`` key.

        Raises:
            ValueError: If ``tool_calls`` cannot be normalized
                (see ``validate_tool_calls``).
        """
        validated_call = self.validate_tool_calls(tool_calls)
        result = await self._interact_tool(validated_call)

        if result.ok:
            priv = result.payload["private_state"]
            pub = result.payload["public_state"]
            return await self._to_observation(priv, pub, self.custom_step_observation_callable)

        priv, pub = self.engine.get_current_states_for_observation()
        return await self._to_observation(
            priv,
            pub,
            self.custom_step_observation_callable,
            extra_obs={"error": result.error},
        )

    async def terminate(self) -> InternalObservation:
        """Return a final observation built with the checkpoint callable."""
        priv, pub = self.engine.get_current_states_for_observation()
        return await self._to_observation(priv, pub, self.custom_checkpoint_observation_callable)

    async def checkpoint(self) -> InternalObservation:
        """Return an observation of the current state built with the checkpoint callable."""
        priv, pub = self.engine.get_current_states_for_observation()
        return await self._to_observation(priv, pub, self.custom_checkpoint_observation_callable)

    @staticmethod
    def _require_minigrid_tool(tool_name: Any) -> None:
        """Raise ``ValueError`` unless ``tool_name`` is ``'minigrid_act'``."""
        if tool_name != "minigrid_act":
            raise ValueError(f"Unknown tool: {tool_name}. Expected 'minigrid_act'.")

    def validate_tool_calls(
        self,
        tool_calls: Union[List[Dict[str, Any]], List[EnvToolCall], Dict[str, Any], EnvToolCall],
    ) -> EnvToolCall:
        """Validate and normalize tool calls into a single ``EnvToolCall``.

        Accepted shapes: an ``EnvToolCall``; a mapping; a non-empty list whose
        first item is either of those; or a non-empty list whose first item is
        a non-empty list of those. Only the first call is used.

        A mapping must name the tool under ``"tool"`` or ``"name"`` and supply
        arguments under ``"args"``, ``"parameters"`` or ``"input"`` (a string
        ``"input"`` is decoded as JSON).

        Raises:
            ValueError: If the shape is not recognised, a required field is
                missing, or the tool is not ``'minigrid_act'``.
        """
        from collections.abc import Mapping

        # Unwrap the outer container(s) down to a single candidate call.
        if isinstance(tool_calls, (EnvToolCall, dict)):
            candidate = tool_calls
        elif isinstance(tool_calls, list) and len(tool_calls) > 0:
            # List of tool calls - take the first one
            candidate = tool_calls[0]
            if isinstance(candidate, list) and len(candidate) > 0:
                # Nested list
                candidate = candidate[0]
        else:
            raise ValueError("Invalid tool_calls format")

        # Already-built calls only need their tool name checked.
        if isinstance(candidate, EnvToolCall):
            self._require_minigrid_tool(candidate.tool)
            return candidate

        # Anything else must behave like a dict. Rejecting it here keeps every
        # malformed input on the ValueError path instead of leaking a
        # TypeError from the key lookups below.
        if not isinstance(candidate, Mapping):
            raise ValueError(
                f"Invalid tool call of type {type(candidate).__name__}; "
                "expected a dict or EnvToolCall"
            )

        # Extract tool name and args - fail fast
        if "tool" in candidate:
            tool_name = candidate["tool"]
        elif "name" in candidate:
            tool_name = candidate["name"]
        else:
            raise ValueError("Tool call missing 'tool' or 'name' field")

        # Handle different argument formats - fail fast
        if "args" in candidate:
            args = candidate["args"]
        elif "parameters" in candidate:
            args = candidate["parameters"]
        elif "input" in candidate:
            raw_input = candidate["input"]
            # json.JSONDecodeError subclasses ValueError, so bad JSON also
            # surfaces as ValueError.
            args = json.loads(raw_input) if isinstance(raw_input, str) else raw_input
        else:
            raise ValueError("Tool call missing 'args', 'parameters', or 'input' field")

        self._require_minigrid_tool(tool_name)

        # Create EnvToolCall
        return EnvToolCall(
            tool=tool_name,
            args=args,
        )

    async def _to_observation(
        self,
        priv: MiniGridPrivateState,
        pub: MiniGridPublicState,
        observation_callable: GetObservationCallable,
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Convert states to an observation using ``observation_callable``.

        The returned observation additionally carries the full state objects
        under ``"public"`` and ``"private"``; ``extra_obs`` entries, if any,
        are merged in last and therefore win on key collisions.
        """
        obs = await observation_callable.get_observation(pub, priv)

        # Attach full state objects for downstream analysis (fail fast)
        obs["public"] = pub
        obs["private"] = priv

        if extra_obs:
            obs.update(extra_obs)

        return obs

    async def _serialize_engine(self) -> Dict[str, Any]:
        """Serialize the engine state into a plain dict of its snapshot fields."""
        snapshot = await self.engine._serialize_engine()
        return {
            "task_instance_dict": snapshot.task_instance_dict,
            "engine_snapshot": snapshot.engine_snapshot,
        }

    @classmethod
    async def _deserialize_engine(cls, data: Dict[str, Any]) -> MiniGridEngine:
        """Rebuild a ``MiniGridEngine`` from a dict produced by ``_serialize_engine``.

        Raises:
            KeyError: If ``data`` lacks ``"task_instance_dict"`` or
                ``"engine_snapshot"``.
        """
        from synth_ai.environments.examples.minigrid.engine import MiniGridEngineSnapshot

        snapshot = MiniGridEngineSnapshot(
            task_instance_dict=data["task_instance_dict"],
            engine_snapshot=data["engine_snapshot"],
        )

        return await MiniGridEngine._deserialize_engine(snapshot)
|
@@ -0,0 +1,242 @@
|
|
1
|
+
"""
|
2
|
+
MiniGrid Environment Mapping Module
|
3
|
+
|
4
|
+
This module provides functionality to map any integer seed to one of 60 MiniGrid
|
5
|
+
environments using modulo arithmetic for deterministic and reproducible
|
6
|
+
environment selection.
|
7
|
+
"""
|
8
|
+
|
9
|
+
from typing import Tuple
|
10
|
+
|
11
|
+
# Environment mapping table (60 total environments)
|
12
|
+
# Index -> environment id. Keys are the list positions (0-59), so the order of
# the entries below defines both the seed mapping and the difficulty tiers.
ENVIRONMENT_MAPPING = dict(
    enumerate(
        [
            # Ultra-Easy (0-4)
            "MiniGrid-Empty-5x5-v0",
            "MiniGrid-Empty-6x6-v0",
            "MiniGrid-Empty-Random-5x5-v0",
            "MiniGrid-Empty-Random-6x6-v0",
            "MiniGrid-GoToDoor-5x5-v0",
            # Easy (5-14)
            "MiniGrid-Empty-8x8-v0",
            "MiniGrid-FourRooms-v0",
            "MiniGrid-DoorKey-5x5-v0",
            "MiniGrid-GoToDoor-6x6-v0",
            "MiniGrid-GoToDoor-8x8-v0",
            "MiniGrid-Fetch-5x5-N2-v0",
            "MiniGrid-Fetch-6x6-N2-v0",
            "MiniGrid-PutNear-6x6-N2-v0",
            "MiniGrid-Unlock-v0",
            "MiniGrid-UnlockPickup-v0",
            # Medium (15-29)
            "MiniGrid-DoorKey-6x6-v0",
            "MiniGrid-DoorKey-8x8-v0",
            "MiniGrid-MultiRoom-N2-S4-v0",
            "MiniGrid-LavaGapS5-v0",
            "MiniGrid-LavaGapS6-v0",
            "MiniGrid-LavaGapS7-v0",
            "MiniGrid-SimpleCrossingS9N1-v0",
            "MiniGrid-SimpleCrossingS9N2-v0",
            "MiniGrid-SimpleCrossingS9N3-v0",
            "MiniGrid-Fetch-8x8-N3-v0",
            "MiniGrid-PutNear-8x8-N3-v0",
            "MiniGrid-RedBlueDoors-6x6-v0",
            "MiniGrid-RedBlueDoors-8x8-v0",
            "MiniGrid-BlockedUnlockPickup-v0",
            "MiniGrid-KeyCorridorS3R1-v0",
            # Hard (30-44)
            "MiniGrid-DoorKey-16x16-v0",
            "MiniGrid-MultiRoom-N4-S5-v0",
            "MiniGrid-MultiRoom-N6-v0",
            "MiniGrid-LavaCrossingS9N1-v0",
            "MiniGrid-LavaCrossingS9N2-v0",
            "MiniGrid-LavaCrossingS9N3-v0",
            "MiniGrid-LavaCrossingS11N5-v0",
            "MiniGrid-SimpleCrossingS11N5-v0",
            "MiniGrid-KeyCorridorS3R2-v0",
            "MiniGrid-KeyCorridorS3R3-v0",
            "MiniGrid-KeyCorridorS4R3-v0",
            "MiniGrid-KeyCorridorS5R3-v0",
            "MiniGrid-KeyCorridorS6R3-v0",
            "MiniGrid-MemoryS7-v0",
            "MiniGrid-MemoryS9-v0",
            # Ultra-Hard (45-54)
            "MiniGrid-MemoryS11-v0",
            "MiniGrid-MemoryS13-v0",
            "MiniGrid-MemoryS13Random-v0",
            "MiniGrid-MemoryS17Random-v0",
            "MiniGrid-LockedRoom-v0",
            "MiniGrid-ObstructedMaze-1Dlh-v0",
            "MiniGrid-ObstructedMaze-1Dlhb-v0",
            "MiniGrid-ObstructedMaze-2Dlhb-v0",
            "MiniGrid-ObstructedMaze-Full-v0",
            "MiniGrid-DistShift1-v0",
            # Specialized (55-59)
            "MiniGrid-DistShift2-v0",
            "MiniGrid-Dynamic-Obstacles-8x8-v0",
            "MiniGrid-Dynamic-Obstacles-16x16-v0",
            "MiniGrid-Playground-v0",
            "MiniGrid-Empty-16x16-v0",
        ]
    )
)
|
80
|
+
|
81
|
+
# Difficulty mapping
|
82
|
+
# Tier name -> (first_index, last_index), both inclusive, into
# ENVIRONMENT_MAPPING. The tiers are contiguous and together cover 0-59.
DIFFICULTY_MAPPING = dict(
    zip(
        ("ultra-easy", "easy", "medium", "hard", "ultra-hard", "specialized"),
        ((0, 4), (5, 14), (15, 29), (30, 44), (45, 54), (55, 59)),
    )
)
|
90
|
+
|
91
|
+
|
92
|
+
def get_environment_from_seed(seed: int, hash_seed: bool = True) -> str:
    """
    Resolve an integer seed to one of the 60 MiniGrid environment names.

    Args:
        seed: Integer seed
        hash_seed: If True, reduce ``hash(seed)`` instead of ``seed`` itself

    Returns:
        Environment name string

    NOTE(review): CPython hashes an int to its own value reduced modulo a
    large prime (with ``hash(-1) == -2``), so for ordinary seeds both modes
    pick the same index; the hashed mode does not spread sequential seeds.
    """
    # The modulus 60 matches the number of entries in ENVIRONMENT_MAPPING.
    key = hash(seed) if hash_seed else seed
    return ENVIRONMENT_MAPPING[key % 60]
|
111
|
+
|
112
|
+
|
113
|
+
def get_difficulty_from_seed(seed: int, hash_seed: bool = True) -> str:
    """
    Resolve an integer seed to the difficulty tier of its environment index.

    The seed is reduced to an index in 0-59 exactly as in
    ``get_environment_from_seed`` and then bucketed by inclusive upper bound.

    Args:
        seed: Integer seed
        hash_seed: If True, reduce ``hash(seed)`` instead of ``seed`` itself

    Returns:
        Difficulty level string
    """
    key = hash(seed) if hash_seed else seed
    slot = key % 60

    # (inclusive upper index, tier) in ascending order; first match wins.
    tiers = (
        (4, "ultra-easy"),
        (14, "easy"),
        (29, "medium"),
        (44, "hard"),
        (54, "ultra-hard"),
    )
    for upper_bound, label in tiers:
        if slot <= upper_bound:
            return label

    # Everything above the last bound (55-59).
    return "specialized"
|
141
|
+
|
142
|
+
|
143
|
+
def get_minigrid_environment(seed: int, hash_seed: bool = True) -> Tuple[str, str]:
    """
    Look up both the environment name and its difficulty tier for a seed.

    Args:
        seed: Integer seed
        hash_seed: Forwarded unchanged to both lookups

    Returns:
        (environment_name, difficulty_level)
    """
    return (
        get_environment_from_seed(seed, hash_seed),
        get_difficulty_from_seed(seed, hash_seed),
    )
|
158
|
+
|
159
|
+
|
160
|
+
def get_environment_by_difficulty(difficulty: str, seed: int = 0) -> str:
    """
    Pick an environment from one difficulty tier.

    The choice is deterministic: ``seed`` is wrapped into the tier's index
    range with modulo, so equal arguments always give the same environment.

    Args:
        difficulty: Difficulty level ("ultra-easy", "easy", "medium", "hard", "ultra-hard", "specialized")
        seed: Seed for selecting within the difficulty range

    Returns:
        Environment name string

    Raises:
        ValueError: If ``difficulty`` is not a key of ``DIFFICULTY_MAPPING``.
    """
    bounds = DIFFICULTY_MAPPING.get(difficulty)
    if bounds is None:
        raise ValueError(f"Unknown difficulty: {difficulty}")

    first, last = bounds  # inclusive index range of the tier
    offset = seed % (last - first + 1)
    return ENVIRONMENT_MAPPING[first + offset]
|
179
|
+
|
180
|
+
|
181
|
+
def get_curriculum_environment(progress: float, seed: int = 0) -> Tuple[str, str]:
    """
    Choose an environment whose difficulty tier follows curriculum progress.

    Args:
        progress: Progress value from 0.0 to 1.0
        seed: Seed for environment selection within difficulty

    Returns:
        (environment_name, difficulty_level)
    """
    # (exclusive upper bound on progress, tier), ascending; first match wins.
    stages = (
        (0.2, "ultra-easy"),  # Early stage
        (0.4, "easy"),  # Beginning
        (0.6, "medium"),  # Intermediate
        (0.8, "hard"),  # Advanced
        (0.9, "ultra-hard"),  # Expert
    )

    # Anything not below the last bound lands in the final tier.
    difficulty = "specialized"
    for ceiling, label in stages:
        if progress < ceiling:
            difficulty = label
            break

    return get_environment_by_difficulty(difficulty, seed), difficulty
|
208
|
+
|
209
|
+
|
210
|
+
def validate_environment_name(env_name: str) -> bool:
    """
    Report whether ``env_name`` is one of the mapped environment names.

    Args:
        env_name: Environment name to validate

    Returns:
        True if supported, False otherwise
    """
    supported_names = ENVIRONMENT_MAPPING.values()
    return env_name in supported_names
|
221
|
+
|
222
|
+
|
223
|
+
def get_all_environments() -> list[str]:
    """Return a new list of every supported environment name, in index order."""
    return [*ENVIRONMENT_MAPPING.values()]
|
226
|
+
|
227
|
+
|
228
|
+
def get_environments_by_difficulty(difficulty: str) -> list[str]:
    """
    List every environment name belonging to one difficulty tier.

    Args:
        difficulty: Difficulty level

    Returns:
        List of environment names

    Raises:
        ValueError: If ``difficulty`` is not a key of ``DIFFICULTY_MAPPING``.
    """
    bounds = DIFFICULTY_MAPPING.get(difficulty)
    if bounds is None:
        raise ValueError(f"Unknown difficulty: {difficulty}")

    first, last = bounds  # inclusive index range of the tier
    tier_indices = range(first, last + 1)
    return [ENVIRONMENT_MAPPING[idx] for idx in tier_indices]
|