miniworld-maze 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of miniworld-maze might be problematic. Click here for more details.
- miniworld_maze/__init__.py +17 -9
- miniworld_maze/core/constants.py +55 -14
- miniworld_maze/core/miniworld_gymnasium/__init__.py +1 -1
- miniworld_maze/core/miniworld_gymnasium/unified_env.py +45 -29
- miniworld_maze/environments/__init__.py +0 -3
- miniworld_maze/environments/base_grid_rooms.py +213 -2
- miniworld_maze/environments/factory.py +38 -151
- miniworld_maze/environments/nine_rooms.py +8 -11
- miniworld_maze/environments/spiral_nine_rooms.py +8 -11
- miniworld_maze/environments/twenty_five_rooms.py +8 -27
- miniworld_maze/tools/__init__.py +1 -3
- miniworld_maze/utils.py +286 -0
- miniworld_maze-1.2.0.dist-info/METADATA +261 -0
- {miniworld_maze-1.0.0.dist-info → miniworld_maze-1.2.0.dist-info}/RECORD +15 -18
- {miniworld_maze-1.0.0.dist-info → miniworld_maze-1.2.0.dist-info}/WHEEL +1 -1
- miniworld_maze/tools/generate_observations.py +0 -199
- miniworld_maze/wrappers/__init__.py +0 -5
- miniworld_maze/wrappers/image_transforms.py +0 -40
- miniworld_maze-1.0.0.dist-info/METADATA +0 -108
- miniworld_maze-1.0.0.dist-info/entry_points.txt +0 -3
miniworld_maze/__init__.py
CHANGED
|
@@ -15,21 +15,29 @@ Main modules:
|
|
|
15
15
|
- tools: Observation generation and utilities
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
+
import os
|
|
19
|
+
import warnings
|
|
20
|
+
|
|
21
|
+
# Set PYGLET_HEADLESS=1 by default if not already set
|
|
22
|
+
if "PYGLET_HEADLESS" not in os.environ:
|
|
23
|
+
os.environ["PYGLET_HEADLESS"] = "1"
|
|
24
|
+
warnings.warn(
|
|
25
|
+
"Automatically set PYGLET_HEADLESS=1 for headless rendering. "
|
|
26
|
+
"Set PYGLET_HEADLESS=0 before importing miniworld_maze to override this behavior.",
|
|
27
|
+
UserWarning,
|
|
28
|
+
stacklevel=2
|
|
29
|
+
)
|
|
30
|
+
|
|
18
31
|
from .core import ObservationLevel
|
|
19
|
-
from .environments.factory import (
|
|
20
|
-
NineRoomsEnvironmentWrapper,
|
|
21
|
-
create_drstrategy_env,
|
|
22
|
-
create_nine_rooms_env,
|
|
23
|
-
)
|
|
24
32
|
from .environments.nine_rooms import NineRooms
|
|
25
33
|
from .environments.spiral_nine_rooms import SpiralNineRooms
|
|
26
34
|
from .environments.twenty_five_rooms import TwentyFiveRooms
|
|
27
35
|
|
|
28
|
-
|
|
36
|
+
# Import factory to register environments
|
|
37
|
+
from .environments import factory # noqa: F401
|
|
38
|
+
|
|
39
|
+
__version__ = "1.1.0"
|
|
29
40
|
__all__ = [
|
|
30
|
-
"create_drstrategy_env",
|
|
31
|
-
"create_nine_rooms_env", # deprecated but kept for backward compatibility
|
|
32
|
-
"NineRoomsEnvironmentWrapper",
|
|
33
41
|
"NineRooms",
|
|
34
42
|
"SpiralNineRooms",
|
|
35
43
|
"TwentyFiveRooms",
|
miniworld_maze/core/constants.py
CHANGED
|
@@ -57,23 +57,64 @@ DEFAULT_BENCHMARK_STEPS: Final[int] = 100
|
|
|
57
57
|
DEFAULT_WARMUP_STEPS: Final[int] = 10
|
|
58
58
|
|
|
59
59
|
# ========================
|
|
60
|
-
#
|
|
60
|
+
# TEXTURE THEMES
|
|
61
61
|
# ========================
|
|
62
62
|
|
|
63
|
+
class TextureThemes:
|
|
64
|
+
"""Pre-defined texture themes for different environments."""
|
|
65
|
+
|
|
66
|
+
NINE_ROOMS = [
|
|
67
|
+
"beige",
|
|
68
|
+
"lightbeige",
|
|
69
|
+
"lightgray",
|
|
70
|
+
"copperred",
|
|
71
|
+
"skyblue",
|
|
72
|
+
"lightcobaltgreen",
|
|
73
|
+
"oakbrown",
|
|
74
|
+
"navyblue",
|
|
75
|
+
"cobaltgreen",
|
|
76
|
+
]
|
|
77
|
+
|
|
78
|
+
SPIRAL_NINE_ROOMS = [
|
|
79
|
+
"beige",
|
|
80
|
+
"lightbeige",
|
|
81
|
+
"lightgray",
|
|
82
|
+
"copperred",
|
|
83
|
+
"skyblue",
|
|
84
|
+
"lightcobaltgreen",
|
|
85
|
+
"oakbrown",
|
|
86
|
+
"navyblue",
|
|
87
|
+
"cobaltgreen",
|
|
88
|
+
]
|
|
89
|
+
|
|
90
|
+
TWENTY_FIVE_ROOMS = [
|
|
91
|
+
"crimson",
|
|
92
|
+
"beanpaste",
|
|
93
|
+
"cobaltgreen",
|
|
94
|
+
"lightnavyblue",
|
|
95
|
+
"skyblue",
|
|
96
|
+
"lightcobaltgreen",
|
|
97
|
+
"oakbrown",
|
|
98
|
+
"copperred",
|
|
99
|
+
"lightgray",
|
|
100
|
+
"lime",
|
|
101
|
+
"turquoise",
|
|
102
|
+
"violet",
|
|
103
|
+
"beige",
|
|
104
|
+
"morningglory",
|
|
105
|
+
"silver",
|
|
106
|
+
"magenta",
|
|
107
|
+
"sunnyyellow",
|
|
108
|
+
"blueberry",
|
|
109
|
+
"lightbeige",
|
|
110
|
+
"seablue",
|
|
111
|
+
"lemongrass",
|
|
112
|
+
"orchid",
|
|
113
|
+
"redbean",
|
|
114
|
+
"orange",
|
|
115
|
+
"realblueberry",
|
|
116
|
+
]
|
|
63
117
|
|
|
64
|
-
# Convenient position calculations for standard room layouts
|
|
65
|
-
class RoomPositions:
|
|
66
|
-
"""Pre-calculated positions for standard room layouts."""
|
|
67
|
-
|
|
68
|
-
# NineRooms (3x3) strategic positions
|
|
69
|
-
NINE_ROOMS_CENTER: Final[tuple[float, float, float]] = (22.5, 0.0, 22.5)
|
|
70
|
-
NINE_ROOMS_TOP_LEFT: Final[tuple[float, float, float]] = (7.5, 0.0, 7.5)
|
|
71
|
-
NINE_ROOMS_BOTTOM_RIGHT: Final[tuple[float, float, float]] = (37.5, 0.0, 37.5)
|
|
72
|
-
|
|
73
|
-
# TwentyFiveRooms (5x5) strategic positions
|
|
74
|
-
TWENTY_FIVE_CENTER: Final[tuple[float, float, float]] = (75.0, 0.0, 75.0)
|
|
75
|
-
TWENTY_FIVE_CORNER: Final[tuple[float, float, float]] = (37.5, 0.0, 37.5)
|
|
76
|
-
TWENTY_FIVE_FAR_CORNER: Final[tuple[float, float, float]] = (112.5, 0.0, 112.5)
|
|
77
118
|
|
|
78
119
|
|
|
79
120
|
# ========================
|
miniworld_maze/core/miniworld_gymnasium/unified_env.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
import math
|
|
4
4
|
from ctypes import POINTER
|
|
5
5
|
from enum import IntEnum
|
|
6
|
+
from typing import List, Optional
|
|
6
7
|
|
|
7
8
|
import gymnasium as gym
|
|
8
9
|
import numpy as np
|
|
@@ -10,6 +11,7 @@ import pyglet
|
|
|
10
11
|
from gymnasium import spaces
|
|
11
12
|
from pyglet.gl import *
|
|
12
13
|
|
|
14
|
+
from ..observation_types import ObservationLevel
|
|
13
15
|
from .entities import *
|
|
14
16
|
from .math import *
|
|
15
17
|
from .objmesh import *
|
|
@@ -95,6 +97,7 @@ class UnifiedMiniWorldEnv(gym.Env):
|
|
|
95
97
|
window_height=DEFAULT_WINDOW_HEIGHT,
|
|
96
98
|
params=DEFAULT_PARAMS,
|
|
97
99
|
domain_rand=False,
|
|
100
|
+
info_obs: Optional[List[ObservationLevel]] = None,
|
|
98
101
|
):
|
|
99
102
|
"""
|
|
100
103
|
Initialize unified MiniWorld environment.
|
|
@@ -110,6 +113,7 @@ class UnifiedMiniWorldEnv(gym.Env):
|
|
|
110
113
|
window_height: Window height for human rendering
|
|
111
114
|
params: Environment parameters for domain randomization
|
|
112
115
|
domain_rand: Whether to enable domain randomization
|
|
116
|
+
info_obs: List of observation levels to include in info dictionary
|
|
113
117
|
"""
|
|
114
118
|
# Store configuration
|
|
115
119
|
self.obs_level = obs_level
|
|
@@ -118,6 +122,7 @@ class UnifiedMiniWorldEnv(gym.Env):
|
|
|
118
122
|
self.max_episode_steps = max_episode_steps
|
|
119
123
|
self.params = params
|
|
120
124
|
self.domain_rand = domain_rand
|
|
125
|
+
self.info_obs = info_obs
|
|
121
126
|
|
|
122
127
|
# Setup action space
|
|
123
128
|
self._setup_action_space()
|
|
@@ -327,22 +332,32 @@ class UnifiedMiniWorldEnv(gym.Env):
|
|
|
327
332
|
self._render_static()
|
|
328
333
|
|
|
329
334
|
# Generate the first camera image
|
|
330
|
-
obs = self._generate_observation()
|
|
335
|
+
obs = self._generate_observation(self.obs_level)
|
|
336
|
+
|
|
337
|
+
# Generate additional observations for info dictionary if specified
|
|
338
|
+
info = {}
|
|
339
|
+
if self.info_obs is not None:
|
|
340
|
+
for obs_level in self.info_obs:
|
|
341
|
+
# Generate observation with the specified level
|
|
342
|
+
info_obs = self._generate_observation(observation_level=obs_level)
|
|
343
|
+
# Use the observation level name as key
|
|
344
|
+
info[str(obs_level)] = info_obs
|
|
331
345
|
|
|
332
346
|
# Return first observation with info dict for Gymnasium compatibility
|
|
333
|
-
return obs,
|
|
347
|
+
return obs, info
|
|
334
348
|
|
|
335
|
-
def _generate_observation(self, render_agent: bool = None):
|
|
336
|
-
"""Generate observation based on
|
|
349
|
+
def _generate_observation(self, observation_level, render_agent: bool = None):
|
|
350
|
+
"""Generate observation based on specified observation level.
|
|
337
351
|
|
|
338
352
|
Args:
|
|
353
|
+
observation_level: Observation level to use.
|
|
339
354
|
render_agent: Whether to render the agent in the observation.
|
|
340
355
|
If None, uses default behavior based on observation level.
|
|
341
356
|
"""
|
|
342
357
|
# Import ObservationLevel here to avoid circular imports
|
|
343
358
|
from ..observation_types import ObservationLevel
|
|
344
359
|
|
|
345
|
-
if
|
|
360
|
+
if observation_level == ObservationLevel.TOP_DOWN_PARTIAL:
|
|
346
361
|
if self.agent_mode == "empty":
|
|
347
362
|
# Agent mode 'empty' always renders without agent
|
|
348
363
|
render_ag = False
|
|
@@ -354,33 +369,21 @@ class UnifiedMiniWorldEnv(gym.Env):
|
|
|
354
369
|
render_ag = True
|
|
355
370
|
return self.render_top_view(POMDP=True, render_ag=render_ag)
|
|
356
371
|
|
|
357
|
-
elif
|
|
372
|
+
elif observation_level == ObservationLevel.TOP_DOWN_FULL:
|
|
358
373
|
# Use explicit render_agent parameter or default to True
|
|
359
374
|
render_ag = render_agent if render_agent is not None else True
|
|
360
375
|
return self.render_top_view(POMDP=False, render_ag=render_ag)
|
|
361
376
|
|
|
362
|
-
elif
|
|
377
|
+
elif observation_level == ObservationLevel.FIRST_PERSON:
|
|
363
378
|
# First person view doesn't include the agent anyway
|
|
364
379
|
return self.render_obs()
|
|
365
380
|
|
|
366
381
|
else:
|
|
367
382
|
valid_levels = list(ObservationLevel)
|
|
368
383
|
raise ValueError(
|
|
369
|
-
f"Invalid obs_level {
|
|
384
|
+
f"Invalid obs_level {observation_level}. Must be one of {valid_levels}"
|
|
370
385
|
)
|
|
371
386
|
|
|
372
|
-
def get_observation(self, render_agent: bool = None):
|
|
373
|
-
"""Public method to generate observation with optional agent rendering control.
|
|
374
|
-
|
|
375
|
-
Args:
|
|
376
|
-
render_agent: Whether to render the agent in the observation.
|
|
377
|
-
If None, uses default behavior based on observation level.
|
|
378
|
-
|
|
379
|
-
Returns:
|
|
380
|
-
np.ndarray: Generated observation image
|
|
381
|
-
"""
|
|
382
|
-
return self._generate_observation(render_agent=render_agent)
|
|
383
|
-
|
|
384
387
|
def _calculate_carried_object_position(self, agent_pos, ent):
|
|
385
388
|
"""Compute the position at which to place an object being carried."""
|
|
386
389
|
dist = self.agent.radius + ent.radius + self.max_forward_step
|
|
@@ -502,7 +505,7 @@ class UnifiedMiniWorldEnv(gym.Env):
|
|
|
502
505
|
self._process_action(action)
|
|
503
506
|
|
|
504
507
|
# Generate observation
|
|
505
|
-
observation = self._generate_observation()
|
|
508
|
+
observation = self._generate_observation(self.obs_level)
|
|
506
509
|
|
|
507
510
|
# Calculate step results
|
|
508
511
|
reward, terminated, info = self._calculate_step_results(observation)
|
|
@@ -577,21 +580,34 @@ class UnifiedMiniWorldEnv(gym.Env):
|
|
|
577
580
|
if self.obs_level != 2: # Not TOP_DOWN_FULL
|
|
578
581
|
topdown = self.render_top_view(POMDP=False, frame_buffer=self.topdown_fb)
|
|
579
582
|
|
|
583
|
+
# Generate additional observations for info dictionary if specified
|
|
584
|
+
info = {}
|
|
585
|
+
if self.info_obs is not None:
|
|
586
|
+
for obs_level in self.info_obs:
|
|
587
|
+
# Generate observation with the specified level
|
|
588
|
+
info_obs = self._generate_observation(observation_level=obs_level)
|
|
589
|
+
# Use the observation level name as key
|
|
590
|
+
info[str(obs_level)] = info_obs
|
|
591
|
+
|
|
580
592
|
# Check termination
|
|
581
593
|
if self.step_count >= self.max_episode_steps:
|
|
582
594
|
terminated = True
|
|
583
595
|
reward = 0
|
|
584
|
-
info
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
596
|
+
info.update(
|
|
597
|
+
{
|
|
598
|
+
"pos": self.agent.pos,
|
|
599
|
+
"mdp_view": topdown if topdown is not None else observation,
|
|
600
|
+
}
|
|
601
|
+
)
|
|
588
602
|
else:
|
|
589
603
|
reward = 0
|
|
590
604
|
terminated = False
|
|
591
|
-
info
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
605
|
+
info.update(
|
|
606
|
+
{
|
|
607
|
+
"pos": self.agent.pos,
|
|
608
|
+
"mdp_view": topdown if topdown is not None else observation,
|
|
609
|
+
}
|
|
610
|
+
)
|
|
595
611
|
|
|
596
612
|
return reward, terminated, info
|
|
597
613
|
|
miniworld_maze/environments/__init__.py
CHANGED
|
@@ -1,15 +1,12 @@
|
|
|
1
1
|
"""Nine Rooms environment implementations."""
|
|
2
2
|
|
|
3
3
|
from .base_grid_rooms import GridRoomsEnvironment
|
|
4
|
-
from .factory import NineRoomsEnvironmentWrapper, create_nine_rooms_env
|
|
5
4
|
from .nine_rooms import NineRooms
|
|
6
5
|
from .spiral_nine_rooms import SpiralNineRooms
|
|
7
6
|
from .twenty_five_rooms import TwentyFiveRooms
|
|
8
7
|
|
|
9
8
|
__all__ = [
|
|
10
9
|
"GridRoomsEnvironment",
|
|
11
|
-
"create_nine_rooms_env",
|
|
12
|
-
"NineRoomsEnvironmentWrapper",
|
|
13
10
|
"NineRooms",
|
|
14
11
|
"SpiralNineRooms",
|
|
15
12
|
"TwentyFiveRooms",
|
miniworld_maze/environments/base_grid_rooms.py
CHANGED
|
@@ -2,6 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import List, Optional, Tuple, Union
|
|
4
4
|
|
|
5
|
+
import cv2
|
|
6
|
+
import numpy as np
|
|
5
7
|
from gymnasium import spaces
|
|
6
8
|
|
|
7
9
|
from ..core import COLORS, Box, ObservationLevel
|
|
@@ -35,6 +37,7 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
|
|
|
35
37
|
grid_size: int,
|
|
36
38
|
connections: List[Tuple[int, int]],
|
|
37
39
|
textures: List[str],
|
|
40
|
+
goal_positions: List[List[List[float]]],
|
|
38
41
|
placed_room: Optional[int] = None,
|
|
39
42
|
obs_level: ObservationLevel = ObservationLevel.TOP_DOWN_PARTIAL,
|
|
40
43
|
continuous: bool = False,
|
|
@@ -52,6 +55,7 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
|
|
|
52
55
|
grid_size: Size of the grid (e.g., 3 for 3x3 grid)
|
|
53
56
|
connections: List of (room1, room2) tuples for connections
|
|
54
57
|
textures: List of texture names for each room
|
|
58
|
+
goal_positions: List of goal positions for each room
|
|
55
59
|
placed_room: Initial room index (defaults to 0)
|
|
56
60
|
obs_level: Observation level (defaults to 1)
|
|
57
61
|
continuous: Whether to use continuous actions (defaults to False)
|
|
@@ -73,10 +77,14 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
|
|
|
73
77
|
|
|
74
78
|
# Validate and set textures
|
|
75
79
|
assert len(textures) == self.total_rooms, (
|
|
76
|
-
f"Textures for floor should be same as the number of the rooms
|
|
80
|
+
f"Textures for floor should be same as the number of the rooms "
|
|
81
|
+
f"({self.total_rooms})"
|
|
77
82
|
)
|
|
78
83
|
self.textures = textures
|
|
79
84
|
|
|
85
|
+
# Set goal positions
|
|
86
|
+
self.goal_positions = goal_positions
|
|
87
|
+
|
|
80
88
|
# Set placed room
|
|
81
89
|
if placed_room is None:
|
|
82
90
|
self.placed_room = 0 # Start in the first room
|
|
@@ -101,6 +109,10 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
|
|
|
101
109
|
# Mark this as a custom environment for background color handling
|
|
102
110
|
self._is_custom_env = True
|
|
103
111
|
|
|
112
|
+
# Store observation dimensions for rendering (needed before super().__init__)
|
|
113
|
+
self.obs_width = obs_width
|
|
114
|
+
self.obs_height = obs_height
|
|
115
|
+
|
|
104
116
|
super().__init__(
|
|
105
117
|
obs_level=obs_level,
|
|
106
118
|
max_episode_steps=MAX_EPISODE_STEPS,
|
|
@@ -114,6 +126,18 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
|
|
|
114
126
|
if not self.continuous:
|
|
115
127
|
self.action_space = spaces.Discrete(self.actions.move_forward + 1)
|
|
116
128
|
|
|
129
|
+
# Store original observation space before updating
|
|
130
|
+
original_obs_space = self.observation_space
|
|
131
|
+
|
|
132
|
+
# Update observation space to include desired_goal and achieved_goal
|
|
133
|
+
self.observation_space = spaces.Dict(
|
|
134
|
+
{
|
|
135
|
+
"observation": original_obs_space,
|
|
136
|
+
"desired_goal": original_obs_space,
|
|
137
|
+
"achieved_goal": original_obs_space,
|
|
138
|
+
}
|
|
139
|
+
)
|
|
140
|
+
|
|
117
141
|
def _generate_world_layout(self, pos=None):
|
|
118
142
|
rooms = []
|
|
119
143
|
|
|
@@ -201,4 +225,191 @@ class GridRoomsEnvironment(UnifiedMiniWorldEnv):
|
|
|
201
225
|
|
|
202
226
|
def step(self, action):
|
|
203
227
|
obs, reward, terminated, truncated, info = super().step(action)
|
|
204
|
-
|
|
228
|
+
|
|
229
|
+
# Check if goal is achieved
|
|
230
|
+
if self._is_goal_achieved():
|
|
231
|
+
terminated = True
|
|
232
|
+
reward = 1.0 # Positive reward for achieving goal
|
|
233
|
+
|
|
234
|
+
# Add agent and goal positions to info dictionary
|
|
235
|
+
agent_pos = self.agent.pos
|
|
236
|
+
info["agent_position"] = np.array([agent_pos[0], agent_pos[2]]) # x, z
|
|
237
|
+
|
|
238
|
+
if hasattr(self, "_current_goal_position"):
|
|
239
|
+
goal_pos = self._current_goal_position
|
|
240
|
+
info["goal_position"] = np.array([goal_pos[0], goal_pos[2]]) # x, z
|
|
241
|
+
|
|
242
|
+
# Return observation as dict
|
|
243
|
+
obs_dict = self._build_observation_dict(obs)
|
|
244
|
+
return obs_dict, reward, terminated, truncated, info
|
|
245
|
+
|
|
246
|
+
def reset(self, seed=None, options=None, pos=None):
|
|
247
|
+
"""
|
|
248
|
+
Reset the environment and generate a new goal.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
seed: Random seed
|
|
252
|
+
options: Additional options
|
|
253
|
+
pos: Agent starting position
|
|
254
|
+
|
|
255
|
+
Returns:
|
|
256
|
+
tuple: (observation, info)
|
|
257
|
+
"""
|
|
258
|
+
# Call parent reset
|
|
259
|
+
obs, info = super().reset(seed=seed, options=options, pos=pos)
|
|
260
|
+
|
|
261
|
+
# Generate goal
|
|
262
|
+
self.desired_goal = self._get_goal()
|
|
263
|
+
|
|
264
|
+
# Add agent and goal positions to info dictionary
|
|
265
|
+
agent_pos = self.agent.pos
|
|
266
|
+
info["agent_position"] = np.array([agent_pos[0], agent_pos[2]]) # x, z
|
|
267
|
+
|
|
268
|
+
if hasattr(self, "_current_goal_position"):
|
|
269
|
+
goal_pos = self._current_goal_position
|
|
270
|
+
info["goal_position"] = np.array([goal_pos[0], goal_pos[2]]) # x, z
|
|
271
|
+
|
|
272
|
+
# Return observation as dict with desired_goal and achieved_goal
|
|
273
|
+
obs_dict = self._build_observation_dict(obs)
|
|
274
|
+
return obs_dict, info
|
|
275
|
+
|
|
276
|
+
def _get_goal(self):
|
|
277
|
+
"""
|
|
278
|
+
Generate a goal by randomly selecting a room and goal position.
|
|
279
|
+
|
|
280
|
+
Returns:
|
|
281
|
+
np.ndarray: Rendered goal image
|
|
282
|
+
"""
|
|
283
|
+
# Select random room
|
|
284
|
+
room_idx = np.random.randint(len(self.goal_positions))
|
|
285
|
+
|
|
286
|
+
# Select random goal within room
|
|
287
|
+
goal_idx = np.random.randint(len(self.goal_positions[room_idx]))
|
|
288
|
+
|
|
289
|
+
# Get goal position
|
|
290
|
+
goal_position = self.goal_positions[room_idx][goal_idx]
|
|
291
|
+
self._current_goal_position = goal_position
|
|
292
|
+
self._current_goal_room = room_idx
|
|
293
|
+
self._current_goal_idx = goal_idx
|
|
294
|
+
|
|
295
|
+
# Render goal image
|
|
296
|
+
goal_image = self.render_on_pos(goal_position)
|
|
297
|
+
|
|
298
|
+
return goal_image
|
|
299
|
+
|
|
300
|
+
def render_on_pos(self, pos):
|
|
301
|
+
"""
|
|
302
|
+
Render observation from a specific position.
|
|
303
|
+
|
|
304
|
+
Args:
|
|
305
|
+
pos: Position to render from [x, y, z]
|
|
306
|
+
|
|
307
|
+
Returns:
|
|
308
|
+
np.ndarray: Rendered observation
|
|
309
|
+
"""
|
|
310
|
+
# Store current agent position
|
|
311
|
+
current_pos = self.agent.pos.copy()
|
|
312
|
+
|
|
313
|
+
# Move agent to target position
|
|
314
|
+
self.place_agent(pos=pos)
|
|
315
|
+
|
|
316
|
+
# Render observation from this position
|
|
317
|
+
obs = self.render_top_view(POMDP=True, render_ag=False)
|
|
318
|
+
|
|
319
|
+
# Resize to match observation dimensions if needed
|
|
320
|
+
if obs.shape[:2] != (self.obs_height, self.obs_width):
|
|
321
|
+
obs = cv2.resize(
|
|
322
|
+
obs, (self.obs_width, self.obs_height), interpolation=cv2.INTER_AREA
|
|
323
|
+
)
|
|
324
|
+
|
|
325
|
+
# Restore agent position
|
|
326
|
+
self.place_agent(pos=current_pos)
|
|
327
|
+
|
|
328
|
+
return obs
|
|
329
|
+
|
|
330
|
+
def _is_goal_achieved(self, pos=None, threshold=0.5):
|
|
331
|
+
"""
|
|
332
|
+
Check if the agent has achieved the current goal.
|
|
333
|
+
|
|
334
|
+
Args:
|
|
335
|
+
pos: Agent position to check (uses current agent pos if None)
|
|
336
|
+
threshold: Distance threshold for goal achievement
|
|
337
|
+
|
|
338
|
+
Returns:
|
|
339
|
+
bool: True if goal is achieved
|
|
340
|
+
"""
|
|
341
|
+
if pos is None:
|
|
342
|
+
pos = self.agent.pos
|
|
343
|
+
|
|
344
|
+
if not hasattr(self, "_current_goal_position"):
|
|
345
|
+
return False
|
|
346
|
+
|
|
347
|
+
# Convert to numpy arrays and calculate distance
|
|
348
|
+
pos_array = np.array(pos)
|
|
349
|
+
goal_array = np.array(self._current_goal_position)
|
|
350
|
+
distance = np.linalg.norm(pos_array - goal_array)
|
|
351
|
+
|
|
352
|
+
return bool(distance < threshold)
|
|
353
|
+
|
|
354
|
+
@staticmethod
|
|
355
|
+
def _generate_goal_positions(
|
|
356
|
+
grid_size: int, room_size: Union[int, float], goals_per_room: int = 2
|
|
357
|
+
) -> List[List[List[float]]]:
|
|
358
|
+
"""
|
|
359
|
+
Generate goal positions for grid layout.
|
|
360
|
+
Args:
|
|
361
|
+
grid_size: Size of the grid (e.g., 3 for 3x3, 5 for 5x5)
|
|
362
|
+
room_size: Size of each room
|
|
363
|
+
goals_per_room: Number of goals per room (1 or 2)
|
|
364
|
+
Returns:
|
|
365
|
+
List of goal positions for each room
|
|
366
|
+
"""
|
|
367
|
+
goal_positions = []
|
|
368
|
+
for i in range(grid_size): # rows
|
|
369
|
+
for j in range(grid_size): # columns
|
|
370
|
+
center_x = room_size * j + room_size / 2
|
|
371
|
+
center_z = room_size * i + room_size / 2
|
|
372
|
+
if goals_per_room == 1:
|
|
373
|
+
# One goal per room at the center
|
|
374
|
+
goal_positions.append([[center_x, 0.0, center_z]])
|
|
375
|
+
else:
|
|
376
|
+
# Two goals per room: center-left and center-right
|
|
377
|
+
goal_positions.append([
|
|
378
|
+
[center_x - 1.0, 0.0, center_z], # left goal
|
|
379
|
+
[center_x + 1.0, 0.0, center_z], # right goal
|
|
380
|
+
])
|
|
381
|
+
return goal_positions
|
|
382
|
+
|
|
383
|
+
def get_extent(self, padding: float = 1.0) -> Tuple[float, float, float, float]:
|
|
384
|
+
"""
|
|
385
|
+
Get the scene extent for use with matplotlib imshow.
|
|
386
|
+
|
|
387
|
+
Returns the scene bounds with padding in the format expected by
|
|
388
|
+
matplotlib's imshow(extent=...) parameter: (left, right, bottom, top).
|
|
389
|
+
|
|
390
|
+
Args:
|
|
391
|
+
padding: Padding to add around environment bounds (default: 1.0)
|
|
392
|
+
|
|
393
|
+
Returns:
|
|
394
|
+
Tuple[float, float, float, float]: (min_x, max_x, min_z, max_z) with padding
|
|
395
|
+
"""
|
|
396
|
+
return (
|
|
397
|
+
self.min_x - padding,
|
|
398
|
+
self.max_x + padding,
|
|
399
|
+
self.min_z - padding,
|
|
400
|
+
self.max_z + padding
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
def _build_observation_dict(self, obs: np.ndarray) -> dict:
|
|
404
|
+
"""
|
|
405
|
+
Build the standard observation dictionary format.
|
|
406
|
+
Args:
|
|
407
|
+
obs: The observation array
|
|
408
|
+
Returns:
|
|
409
|
+
Dictionary with observation, desired_goal, and achieved_goal
|
|
410
|
+
"""
|
|
411
|
+
return {
|
|
412
|
+
"observation": obs,
|
|
413
|
+
"desired_goal": self.desired_goal,
|
|
414
|
+
"achieved_goal": obs,
|
|
415
|
+
}
|