miniworld-maze 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -15,17 +15,29 @@ Main modules:
 - tools: Observation generation and utilities
 """
 
+import os
+import warnings
+
+# Set PYGLET_HEADLESS=1 by default if not already set
+if "PYGLET_HEADLESS" not in os.environ:
+    os.environ["PYGLET_HEADLESS"] = "1"
+    warnings.warn(
+        "Automatically set PYGLET_HEADLESS=1 for headless rendering. "
+        "Set PYGLET_HEADLESS=0 before importing miniworld_maze to override this behavior.",
+        UserWarning,
+        stacklevel=2
+    )
+
 from .core import ObservationLevel
-from .environments.factory import (
-    NineRoomsEnvironmentWrapper,
-)
 from .environments.nine_rooms import NineRooms
 from .environments.spiral_nine_rooms import SpiralNineRooms
 from .environments.twenty_five_rooms import TwentyFiveRooms
 
-__version__ = "1.0.0"
+# Import factory to register environments
+from .environments import factory  # noqa: F401
+
+__version__ = "1.1.0"
 __all__ = [
-    "NineRoomsEnvironmentWrapper",
     "NineRooms",
     "SpiralNineRooms",
     "TwentyFiveRooms",
@@ -57,23 +57,64 @@ DEFAULT_BENCHMARK_STEPS: Final[int] = 100
 DEFAULT_WARMUP_STEPS: Final[int] = 10
 
 # ========================
-# OBSERVATION TEST POSITIONS
+# TEXTURE THEMES
 # ========================
 
+class TextureThemes:
+    """Pre-defined texture themes for different environments."""
+
+    NINE_ROOMS = [
+        "beige",
+        "lightbeige",
+        "lightgray",
+        "copperred",
+        "skyblue",
+        "lightcobaltgreen",
+        "oakbrown",
+        "navyblue",
+        "cobaltgreen",
+    ]
+
+    SPIRAL_NINE_ROOMS = [
+        "beige",
+        "lightbeige",
+        "lightgray",
+        "copperred",
+        "skyblue",
+        "lightcobaltgreen",
+        "oakbrown",
+        "navyblue",
+        "cobaltgreen",
+    ]
+
+    TWENTY_FIVE_ROOMS = [
+        "crimson",
+        "beanpaste",
+        "cobaltgreen",
+        "lightnavyblue",
+        "skyblue",
+        "lightcobaltgreen",
+        "oakbrown",
+        "copperred",
+        "lightgray",
+        "lime",
+        "turquoise",
+        "violet",
+        "beige",
+        "morningglory",
+        "silver",
+        "magenta",
+        "sunnyyellow",
+        "blueberry",
+        "lightbeige",
+        "seablue",
+        "lemongrass",
+        "orchid",
+        "redbean",
+        "orange",
+        "realblueberry",
+    ]
 
-# Convenient position calculations for standard room layouts
-class RoomPositions:
-    """Pre-calculated positions for standard room layouts."""
-
-    # NineRooms (3x3) strategic positions
-    NINE_ROOMS_CENTER: Final[tuple[float, float, float]] = (22.5, 0.0, 22.5)
-    NINE_ROOMS_TOP_LEFT: Final[tuple[float, float, float]] = (7.5, 0.0, 7.5)
-    NINE_ROOMS_BOTTOM_RIGHT: Final[tuple[float, float, float]] = (37.5, 0.0, 37.5)
-
-    # TwentyFiveRooms (5x5) strategic positions
-    TWENTY_FIVE_CENTER: Final[tuple[float, float, float]] = (75.0, 0.0, 75.0)
-    TWENTY_FIVE_CORNER: Final[tuple[float, float, float]] = (37.5, 0.0, 37.5)
-    TWENTY_FIVE_FAR_CORNER: Final[tuple[float, float, float]] = (112.5, 0.0, 112.5)
 
 
 # ========================
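
The texture lists previously duplicated on each environment class now live here as TextureThemes (the nine_rooms and spiral_nine_rooms hunks further down switch to them). A quick consistency check, assuming the intent is one floor texture per room as the base-class assertion below requires:

    from miniworld_maze.core.constants import TextureThemes

    # 3x3 grids need 9 entries, the 5x5 grid needs 25 (len(textures) == total_rooms).
    assert len(TextureThemes.NINE_ROOMS) == 9
    assert len(TextureThemes.SPIRAL_NINE_ROOMS) == 9
    assert len(TextureThemes.TWENTY_FIVE_ROOMS) == 25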
@@ -1,4 +1,4 @@
 # Core module cleaned up - imports removed since files were consolidated
 
 # Import the envs module so that envs register themselves
-from . import envs
+from . import envs as envs
@@ -21,6 +21,7 @@ class MiniWorldEnv(UnifiedMiniWorldEnv):
         window_height=600,
         params=DEFAULT_PARAMS,
         domain_rand=False,
+        render_mode=None,
     ):
         """
         Initialize base MiniWorld environment.
@@ -33,6 +34,7 @@ class MiniWorldEnv(UnifiedMiniWorldEnv):
            window_height: Window height for human rendering
            params: Environment parameters for domain randomization
            domain_rand: Whether to enable domain randomization
+           render_mode: Render mode ("human", "rgb_array", or None)
         """
         # Mark this as a base environment (not custom) for background color handling
         self._is_custom_env = False
@@ -49,4 +51,5 @@ class MiniWorldEnv(UnifiedMiniWorldEnv):
            window_height=window_height,
            params=params,
            domain_rand=domain_rand,
+           render_mode=render_mode,
        )
@@ -63,7 +63,7 @@ class UnifiedMiniWorldEnv(gym.Env):
     both the enhanced features of CustomMiniWorldEnv and the legacy BaseEnv functionality.
     """
 
-    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 30}
+    metadata = {"render_modes": ["human", "rgb_array"], "video.frames_per_second": 30}
 
     # Enumeration of possible actions
     class Actions(IntEnum):
@@ -98,6 +98,7 @@ class UnifiedMiniWorldEnv(gym.Env):
         params=DEFAULT_PARAMS,
         domain_rand=False,
         info_obs: Optional[List[ObservationLevel]] = None,
+        render_mode=None,
     ):
         """
         Initialize unified MiniWorld environment.
@@ -114,6 +115,7 @@ class UnifiedMiniWorldEnv(gym.Env):
            params: Environment parameters for domain randomization
            domain_rand: Whether to enable domain randomization
            info_obs: List of observation levels to include in info dictionary
+           render_mode: Render mode ("human", "rgb_array", or None)
         """
         # Store configuration
         self.obs_level = obs_level
@@ -123,6 +125,13 @@ class UnifiedMiniWorldEnv(gym.Env):
         self.params = params
         self.domain_rand = domain_rand
         self.info_obs = info_obs
+        self.render_mode = render_mode
+
+        # Validate render_mode
+        if render_mode is not None and render_mode not in self.metadata["render_modes"]:
+            raise ValueError(
+                f"render_mode must be one of {self.metadata['render_modes']}, got {render_mode}"
+            )
 
         # Setup action space
         self._setup_action_space()
@@ -332,34 +341,32 @@ class UnifiedMiniWorldEnv(gym.Env):
         self._render_static()
 
         # Generate the first camera image
-        obs = self._generate_observation()
+        obs = self._generate_observation(self.obs_level)
 
         # Generate additional observations for info dictionary if specified
         info = {}
         if self.info_obs is not None:
             for obs_level in self.info_obs:
-                # Temporarily change obs_level to generate the desired observation
-                original_obs_level = self.obs_level
-                self.obs_level = obs_level
-                info_obs = self._generate_observation()
-                self.obs_level = original_obs_level
+                # Generate observation with the specified level
+                info_obs = self._generate_observation(observation_level=obs_level)
                 # Use the observation level name as key
                 info[str(obs_level)] = info_obs
 
         # Return first observation with info dict for Gymnasium compatibility
         return obs, info
 
-    def _generate_observation(self, render_agent: bool = None):
-        """Generate observation based on current observation level.
+    def _generate_observation(self, observation_level, render_agent: bool = None):
+        """Generate observation based on specified observation level.
 
         Args:
+            observation_level: Observation level to use.
            render_agent: Whether to render the agent in the observation.
                If None, uses default behavior based on observation level.
        """
        # Import ObservationLevel here to avoid circular imports
        from ..observation_types import ObservationLevel
 
-        if self.obs_level == ObservationLevel.TOP_DOWN_PARTIAL:
+        if observation_level == ObservationLevel.TOP_DOWN_PARTIAL:
            if self.agent_mode == "empty":
                # Agent mode 'empty' always renders without agent
                render_ag = False
@@ -371,19 +378,19 @@ class UnifiedMiniWorldEnv(gym.Env):
                render_ag = True
            return self.render_top_view(POMDP=True, render_ag=render_ag)
 
-        elif self.obs_level == ObservationLevel.TOP_DOWN_FULL:
+        elif observation_level == ObservationLevel.TOP_DOWN_FULL:
            # Use explicit render_agent parameter or default to True
            render_ag = render_agent if render_agent is not None else True
            return self.render_top_view(POMDP=False, render_ag=render_ag)
 
-        elif self.obs_level == ObservationLevel.FIRST_PERSON:
+        elif observation_level == ObservationLevel.FIRST_PERSON:
            # First person view doesn't include the agent anyway
            return self.render_obs()
 
        else:
            valid_levels = list(ObservationLevel)
            raise ValueError(
-                f"Invalid obs_level {self.obs_level}. Must be one of {valid_levels}"
+                f"Invalid obs_level {observation_level}. Must be one of {valid_levels}"
            )
 
    def _calculate_carried_object_position(self, agent_pos, ent):
@@ -507,7 +514,7 @@ class UnifiedMiniWorldEnv(gym.Env):
         self._process_action(action)
 
         # Generate observation
-        observation = self._generate_observation()
+        observation = self._generate_observation(self.obs_level)
 
         # Calculate step results
         reward, terminated, info = self._calculate_step_results(observation)
@@ -586,11 +593,8 @@ class UnifiedMiniWorldEnv(gym.Env):
         info = {}
         if self.info_obs is not None:
             for obs_level in self.info_obs:
-                # Temporarily change obs_level to generate the desired observation
-                original_obs_level = self.obs_level
-                self.obs_level = obs_level
-                info_obs = self._generate_observation()
-                self.obs_level = original_obs_level
+                # Generate observation with the specified level
+                info_obs = self._generate_observation(observation_level=obs_level)
                 # Use the observation level name as key
                 info[str(obs_level)] = info_obs
 
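
Both reset() and step() now pass the requested level straight into _generate_observation instead of temporarily swapping self.obs_level, so info_obs behaves the same but without mutating state mid-call. A sketch of how the extra views are keyed, assuming NineRooms accepts info_obs as the old factory wrapper did:

    from miniworld_maze import NineRooms
    from miniworld_maze.core import ObservationLevel

    env = NineRooms(info_obs=[ObservationLevel.TOP_DOWN_FULL])
    obs_dict, info = env.reset()

    # Each requested level is stored in info under str(obs_level).
    full_view = info[str(ObservationLevel.TOP_DOWN_FULL)]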
@@ -1,14 +1,12 @@
 """Nine Rooms environment implementations."""
 
 from .base_grid_rooms import GridRoomsEnvironment
-from .factory import NineRoomsEnvironmentWrapper
 from .nine_rooms import NineRooms
 from .spiral_nine_rooms import SpiralNineRooms
 from .twenty_five_rooms import TwentyFiveRooms
 
 __all__ = [
     "GridRoomsEnvironment",
-    "NineRoomsEnvironmentWrapper",
     "NineRooms",
     "SpiralNineRooms",
     "TwentyFiveRooms",
@@ -46,6 +46,7 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
         agent_mode: Optional[str] = None,
         obs_width: int = DEFAULT_OBS_WIDTH,
         obs_height: int = DEFAULT_OBS_HEIGHT,
+        render_mode=None,
         **kwargs,
     ):
         """
@@ -64,6 +65,7 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
            agent_mode: Agent rendering mode ('triangle', 'circle', 'empty')
            obs_width: Observation width in pixels (defaults to DEFAULT_OBS_WIDTH)
            obs_height: Observation height in pixels (defaults to DEFAULT_OBS_HEIGHT)
+           render_mode: Render mode ("human", "rgb_array", or None)
            **kwargs: Additional arguments passed to parent class
        """
 
@@ -77,7 +79,8 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
 
         # Validate and set textures
         assert len(textures) == self.total_rooms, (
-            f"Textures for floor should be same as the number of the rooms ({self.total_rooms})"
+            f"Textures for floor should be same as the number of the rooms "
+            f"({self.total_rooms})"
         )
         self.textures = textures
 
@@ -119,6 +122,7 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
            agent_mode=self.agent_mode,
            obs_width=obs_width,
            obs_height=obs_height,
+           render_mode=render_mode,
            **kwargs,
        )
 
@@ -226,16 +230,20 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
         obs, reward, terminated, truncated, info = super().step(action)
 
         # Check if goal is achieved
-        if self.is_goal_achieved():
+        if self._is_goal_achieved():
             terminated = True
             reward = 1.0  # Positive reward for achieving goal
 
+        # Add agent and goal positions to info dictionary
+        agent_pos = self.agent.pos
+        info["agent_position"] = np.array([agent_pos[0], agent_pos[2]])  # x, z
+
+        if hasattr(self, "_current_goal_position"):
+            goal_pos = self._current_goal_position
+            info["goal_position"] = np.array([goal_pos[0], goal_pos[2]])  # x, z
+
         # Return observation as dict
-        obs_dict = {
-            "observation": obs,
-            "desired_goal": self.desired_goal,
-            "achieved_goal": obs,
-        }
+        obs_dict = self._build_observation_dict(obs)
         return obs_dict, reward, terminated, truncated, info
 
     def reset(self, seed=None, options=None, pos=None):
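
step() keeps the dict observation format (now built by _build_observation_dict, added further down) and additionally reports planar coordinates in info. A sketch of consuming a transition, reusing the env from the earlier sketch; key names are taken from the hunk:

    obs_dict, reward, terminated, truncated, info = env.step(env.action_space.sample())

    frame = obs_dict["observation"]      # current rendered observation
    goal = obs_dict["desired_goal"]      # goal observation generated at reset()
    agent_xz = info["agent_position"]    # np.array([x, z])
    goal_xz = info.get("goal_position")  # present when _current_goal_position is set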
@@ -254,17 +262,21 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
         obs, info = super().reset(seed=seed, options=options, pos=pos)
 
         # Generate goal
-        self.desired_goal = self.get_goal()
+        self.desired_goal = self._get_goal()
+
+        # Add agent and goal positions to info dictionary
+        agent_pos = self.agent.pos
+        info["agent_position"] = np.array([agent_pos[0], agent_pos[2]])  # x, z
+
+        if hasattr(self, "_current_goal_position"):
+            goal_pos = self._current_goal_position
+            info["goal_position"] = np.array([goal_pos[0], goal_pos[2]])  # x, z
 
         # Return observation as dict with desired_goal and achieved_goal
-        obs_dict = {
-            "observation": obs,
-            "desired_goal": self.desired_goal,
-            "achieved_goal": obs,
-        }
+        obs_dict = self._build_observation_dict(obs)
         return obs_dict, info
 
-    def get_goal(self):
+    def _get_goal(self):
        """
        Generate a goal by randomly selecting a room and goal position.
 
@@ -318,7 +330,7 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
 
         return obs
 
-    def is_goal_achieved(self, pos=None, threshold=0.5):
+    def _is_goal_achieved(self, pos=None, threshold=0.5):
        """
        Check if the agent has achieved the current goal.
 
@@ -341,3 +353,66 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
         distance = np.linalg.norm(pos_array - goal_array)
 
         return bool(distance < threshold)
+
+    @staticmethod
+    def _generate_goal_positions(
+        grid_size: int, room_size: Union[int, float], goals_per_room: int = 2
+    ) -> List[List[List[float]]]:
+        """
+        Generate goal positions for grid layout.
+        Args:
+            grid_size: Size of the grid (e.g., 3 for 3x3, 5 for 5x5)
+            room_size: Size of each room
+            goals_per_room: Number of goals per room (1 or 2)
+        Returns:
+            List of goal positions for each room
+        """
+        goal_positions = []
+        for i in range(grid_size):  # rows
+            for j in range(grid_size):  # columns
+                center_x = room_size * j + room_size / 2
+                center_z = room_size * i + room_size / 2
+                if goals_per_room == 1:
+                    # One goal per room at the center
+                    goal_positions.append([[center_x, 0.0, center_z]])
+                else:
+                    # Two goals per room: center-left and center-right
+                    goal_positions.append([
+                        [center_x - 1.0, 0.0, center_z],  # left goal
+                        [center_x + 1.0, 0.0, center_z],  # right goal
+                    ])
+        return goal_positions
+
+    def get_extent(self, padding: float = 1.0) -> Tuple[float, float, float, float]:
+        """
+        Get the scene extent for use with matplotlib imshow.
+
+        Returns the scene bounds with padding in the format expected by
+        matplotlib's imshow(extent=...) parameter: (left, right, bottom, top).
+
+        Args:
+            padding: Padding to add around environment bounds (default: 1.0)
+
+        Returns:
+            Tuple[float, float, float, float]: (min_x, max_x, min_z, max_z) with padding
+        """
+        return (
+            self.min_x - padding,
+            self.max_x + padding,
+            self.min_z - padding,
+            self.max_z + padding
+        )
+
+    def _build_observation_dict(self, obs: np.ndarray) -> dict:
+        """
+        Build the standard observation dictionary format.
+        Args:
+            obs: The observation array
+        Returns:
+            Dictionary with observation, desired_goal, and achieved_goal
+        """
+        return {
+            "observation": obs,
+            "desired_goal": self.desired_goal,
+            "achieved_goal": obs,
+        }
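
The new helpers are small but worth a worked example: _generate_goal_positions derives goals analytically from the grid geometry, and get_extent returns bounds shaped for matplotlib's imshow. A sketch follows; room_size=15 is an illustrative value rather than the package default, calling the private static method directly is for demonstration only, and NineRooms is assumed to be constructible with its defaults:

    import matplotlib.pyplot as plt

    from miniworld_maze import NineRooms
    from miniworld_maze.environments import GridRoomsEnvironment

    # 3x3 grid, two goals per room, illustrative room_size of 15 units.
    positions = GridRoomsEnvironment._generate_goal_positions(3, 15, goals_per_room=2)
    assert len(positions) == 9  # one entry per room
    assert positions[0] == [[6.5, 0.0, 7.5], [8.5, 0.0, 7.5]]  # room (0, 0): left/right of center

    # Overlay a full top-down render on world coordinates via get_extent().
    env = NineRooms()
    env.reset()
    plt.imshow(env.render_top_view(POMDP=False, render_ag=True), extent=env.get_extent(padding=1.0))
    plt.show()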
@@ -1,105 +1,42 @@
-"""Factory for creating Nine Rooms environment variants."""
-
-from typing import List
-
-import gymnasium as gym
-import numpy as np
+"""Gymnasium environment registrations for Nine Rooms environment variants."""
 
+from gymnasium.envs.registration import register
 from ..core import ObservationLevel
 from ..core.constants import FACTORY_DOOR_SIZE, FACTORY_ROOM_SIZE
-from .nine_rooms import NineRooms
-from .spiral_nine_rooms import SpiralNineRooms
-from .twenty_five_rooms import TwentyFiveRooms
-
-
-class NineRoomsEnvironmentWrapper(gym.Wrapper):
-    """Unified wrapper for all Nine Rooms environment variants."""
-
-    def __init__(
-        self,
-        variant="NineRooms",
-        obs_level=ObservationLevel.TOP_DOWN_PARTIAL,
-        continuous=False,
-        size=64,
-        room_size=FACTORY_ROOM_SIZE,
-        door_size=FACTORY_DOOR_SIZE,
-        agent_mode=None,
-        info_obs: List[ObservationLevel] = None,
-    ):
-        """
-        Create a Nine Rooms environment variant.
-
-        Args:
-            variant: Environment variant ("NineRooms", "SpiralNineRooms", "TwentyFiveRooms")
-            obs_level: Observation level (ObservationLevel enum)
-            continuous: Whether to use continuous actions
-            size: Observation image size (rendered directly at this size to avoid resizing)
-            room_size: Size of each room in environment units
-            door_size: Size of doors between rooms
-            agent_mode: Agent rendering mode ('empty', 'circle', 'triangle', or None for default)
-            info_obs: List of observation levels to include in info dictionary
-        """
-        self.variant = variant
-
-        # Select the appropriate environment class
-        env_classes = {
-            "NineRooms": NineRooms,
-            "SpiralNineRooms": SpiralNineRooms,
-            "TwentyFiveRooms": TwentyFiveRooms,
-        }
-
-        if variant not in env_classes:
-            raise ValueError(
-                f"Unknown variant '{variant}'. Available: {list(env_classes.keys())}"
-            )
-
-        env_class = env_classes[variant]
-
-        # Create base environment with direct rendering size
-        base_env = env_class(
-            room_size=room_size,
-            door_size=door_size,
-            obs_level=obs_level,
-            continuous=continuous,
-            obs_width=size,
-            obs_height=size,
-            agent_mode=agent_mode,
-            info_obs=info_obs,
-        )
-
-        # Apply wrappers - no resize needed since we render at target size
-
-        # Initialize gym.Wrapper with the base environment
-        super().__init__(base_env)
-
-    def render_on_pos(self, pos):
-        """Render observation from a specific position."""
-        # Get access to the base environment
-        base_env = self.env
-        while hasattr(base_env, "env") or hasattr(base_env, "_env"):
-            if hasattr(base_env, "env"):
-                base_env = base_env.env
-            elif hasattr(base_env, "_env"):
-                base_env = base_env._env
-            else:
-                break
-
-        # Store original position
-        original_pos = base_env.agent.pos.copy()
-
-        # Move agent to target position
-        base_env.place_agent(pos=pos)
-
-        # Get first-person observation from the agent's perspective at this position
-        obs = base_env.render_obs()
-
-        # Restore original position
-        base_env.place_agent(pos=original_pos)
-
-        # Apply wrapper transformations manually for consistency
-        # Convert to PyTorch format (CHW) - no resize needed since we render at target size
-        obs = np.transpose(obs, (2, 0, 1))
-
-        return obs
-
 
+# Register environment variants with factory defaults matching the original wrapper
+register(
+    id="NineRooms-v0",
+    entry_point="miniworld_maze.environments.nine_rooms:NineRooms",
+    max_episode_steps=1000,
+    kwargs={
+        "room_size": FACTORY_ROOM_SIZE,
+        "door_size": FACTORY_DOOR_SIZE,
+        "obs_level": ObservationLevel.TOP_DOWN_PARTIAL,
+        "agent_mode": None,  # becomes "empty" by default
+    },
+)
+
+register(
+    id="SpiralNineRooms-v0",
+    entry_point="miniworld_maze.environments.spiral_nine_rooms:SpiralNineRooms",
+    max_episode_steps=1000,
+    kwargs={
+        "room_size": FACTORY_ROOM_SIZE,
+        "door_size": FACTORY_DOOR_SIZE,
+        "obs_level": ObservationLevel.TOP_DOWN_PARTIAL,
+        "agent_mode": None,
+    },
+)
+
+register(
+    id="TwentyFiveRooms-v0",
+    entry_point="miniworld_maze.environments.twenty_five_rooms:TwentyFiveRooms",
+    max_episode_steps=1000,
+    kwargs={
+        "room_size": FACTORY_ROOM_SIZE,
+        "door_size": FACTORY_DOOR_SIZE,
+        "obs_level": ObservationLevel.TOP_DOWN_PARTIAL,
+        "agent_mode": None,
+    },
+)
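
The NineRoomsEnvironmentWrapper is gone; the same factory defaults are now exposed through Gymnasium's registry, and the register() calls run when the package is imported (see the factory import added to __init__.py at the top of this diff). A sketch of the replacement workflow:

    import gymnasium as gym

    import miniworld_maze  # noqa: F401 -- importing executes the register() calls above

    # Replaces NineRoomsEnvironmentWrapper(variant="NineRooms", ...);
    # a TimeLimit of 1000 steps comes from max_episode_steps.
    env = gym.make("NineRooms-v0")
    obs_dict, info = env.reset(seed=0)
    obs_dict, reward, terminated, truncated, info = env.step(env.action_space.sample())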
@@ -1,6 +1,7 @@
 """NineRooms environment implementation."""
 
 from ..core import ObservationLevel
+from ..core.constants import TextureThemes
 from .base_grid_rooms import GridRoomsEnvironment
 
 
@@ -46,31 +47,12 @@ class NineRooms(GridRoomsEnvironment):
            (6, 7),
            (7, 8),
        ]
-        default_textures = [
-            "beige",
-            "lightbeige",
-            "lightgray",
-            "copperred",
-            "skyblue",
-            "lightcobaltgreen",
-            "oakbrown",
-            "navyblue",
-            "cobaltgreen",
-        ]
+        default_textures = TextureThemes.NINE_ROOMS
 
        # Initialize goal positions for each room (2 goals per room)
-        goal_positions = []
-        for i in range(3):  # rows
-            for j in range(3):  # columns
-                center_x = room_size * j + room_size / 2
-                center_z = room_size * i + room_size / 2
-                # Two goals per room: center-left and center-right
-                goal_positions.append(
-                    [
-                        [center_x - 1.0, 0.0, center_z],  # left goal
-                        [center_x + 1.0, 0.0, center_z],  # right goal
-                    ]
-                )
+        goal_positions = GridRoomsEnvironment._generate_goal_positions(
+            3, room_size, goals_per_room=2
+        )
 
        super().__init__(
            grid_size=3,
@@ -1,6 +1,7 @@
 """SpiralNineRooms environment implementation."""
 
 from ..core import ObservationLevel
+from ..core.constants import TextureThemes
 from .base_grid_rooms import GridRoomsEnvironment
 
 
@@ -42,31 +43,12 @@ class SpiralNineRooms(GridRoomsEnvironment):
            (6, 7),
            (7, 8),
        ]
-        default_textures = [
-            "beige",
-            "lightbeige",
-            "lightgray",
-            "copperred",
-            "skyblue",
-            "lightcobaltgreen",
-            "oakbrown",
-            "navyblue",
-            "cobaltgreen",
-        ]
+        default_textures = TextureThemes.SPIRAL_NINE_ROOMS
 
        # Initialize goal positions for each room (2 goals per room)
-        goal_positions = []
-        for i in range(3):  # rows
-            for j in range(3):  # columns
-                center_x = room_size * j + room_size / 2
-                center_z = room_size * i + room_size / 2
-                # Two goals per room: center-left and center-right
-                goal_positions.append(
-                    [
-                        [center_x - 1.0, 0.0, center_z],  # left goal
-                        [center_x + 1.0, 0.0, center_z],  # right goal
-                    ]
-                )
+        goal_positions = GridRoomsEnvironment._generate_goal_positions(
+            3, room_size, goals_per_room=2
+        )
 
        super().__init__(
            grid_size=3,