duobench 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- duobench/__init__.py +119 -0
- duobench/__main__.py +20 -0
- duobench/tasks/__init__.py +81 -0
- duobench/tasks/ball_maze.py +251 -0
- duobench/tasks/bin_sort.py +276 -0
- duobench/tasks/block_balance.py +265 -0
- duobench/tasks/carry_pot.py +264 -0
- duobench/tasks/example.py +43 -0
- duobench/tasks/hinge_chest.py +225 -0
- duobench/tasks/join_blocks.py +189 -0
- duobench/tasks/pour_marbles.py +419 -0
- duobench/tasks/spring_door.py +167 -0
- duobench/tasks/transfer_cube.py +246 -0
- duobench/tasks/transfer_gate.py +247 -0
- duobench/tasks/transfer_reorient.py +221 -0
- duobench/utils/__init__.py +0 -0
- duobench/utils/helper_wrappers.py +204 -0
- duobench/utils/vention_config.py +42 -0
- duobench-0.1.0.dist-info/METADATA +211 -0
- duobench-0.1.0.dist-info/RECORD +26 -0
- duobench-0.1.0.dist-info/WHEEL +5 -0
- duobench-0.1.0.dist-info/licenses/LICENSE +201 -0
- duobench-0.1.0.dist-info/top_level.txt +2 -0
- tests/test_ball_maze_stage.py +12 -0
- tests/test_random_square_objs_pos.py +91 -0
- tests/test_replay_stage_detection.py +125 -0
duobench/__init__.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Package version string; it is also passed to get_prefix() when DUOBENCH_PREFIX is resolved.
__version__ = "0.1.0"
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import rcs
|
|
8
|
+
from rcs import get_prefix
|
|
9
|
+
from rcs._core import common
|
|
10
|
+
|
|
11
|
+
# URL template (with a "{tag}" placeholder) for the GitHub source archive of a tagged release.
# It is handed to get_prefix() together with __version__ — presumably so rcs can fetch the
# assets when none are found locally (NOTE(review): confirm against rcs.get_prefix).
DUOBENCH_GITHUB_ASSET_ARCHIVE_URL = "https://github.com/RobotControlStack/duobench/archive/refs/tags/{tag}.zip"

# Root directory onto which every relative asset path in this module is joined.
# Arguments passed to rcs.get_prefix, in order:
#   - the DUOBENCH_PREFIX environment variable (None when unset),
#   - the directory two levels above this package's directory,
#   - ~/.duobench,
#   - the package version and the archive URL template above.
# NOTE(review): how get_prefix chooses between these candidates is not visible here.
DUOBENCH_PREFIX = get_prefix(
    os.environ.get("DUOBENCH_PREFIX"),
    Path(__file__).resolve().parents[2],
    Path.home() / ".duobench",
    __version__,
    DUOBENCH_GITHUB_ASSET_ARCHIVE_URL,
)
|
|
20
|
+
|
|
21
|
+
# Asset registries: logical asset name -> XML path relative to DUOBENCH_PREFIX.
# Later in this module every entry is rewritten to an absolute path (a FileNotFoundError
# is raised if the file is missing), and the dictionaries are merged into the matching
# rcs registries.
CAMERA_PATHS: dict[str, str] = {}
GRIPPER_PATHS: dict[str, str] = {}
# NOTE: several keys keep legacy prefixes or spellings that differ from their asset folder
# (block_stacking_* -> block_balance/, parallel_pick_* -> bin_sort/,
# handover_hole_* -> transfer_gate/, "squere"/"triangel"). Task modules look assets up by
# key (ball_maze builds f"maze{n}"), so presumably other tasks depend on these exact names —
# do not rename them without updating the lookups.
OBJECT_PATHS: dict[str, str] = {
    # transfer_reorient
    "transfer_reorient_socket": "assets/objects/transfer_reorient/socket.xml",
    "transfer_reorient_hex_peg": "assets/objects/transfer_reorient/hexpeg.xml",
    "transfer_reorient_squere_peg": "assets/objects/transfer_reorient/squere_peg.xml",
    "transfer_reorient_triangel_peg": "assets/objects/transfer_reorient/triangle_peg.xml",
    # block_stacking (assets live in the block_balance folder)
    "block_stacking_big_red_cube": "assets/objects/block_balance/big_red_cube.xml",
    "block_stacking_blue_rectangle": "assets/objects/block_balance/blue_rectangle.xml",
    "block_stacking_green_rectangle": "assets/objects/block_balance/green_rectangle.xml",
    "block_stacking_pink_beam": "assets/objects/block_balance/pink_beam.xml",
    # bin_sort (keys use the parallel_pick_ prefix)
    "parallel_pick_white_bowl": "assets/objects/bin_sort/white_bowl.xml",
    "parallel_pick_white_tri_cylinder": "assets/objects/bin_sort/white_tri_cylinder.xml",
    "parallel_pick_white_box": "assets/objects/bin_sort/white_box.xml",
    "parallel_pick_white_pent_cylinder": "assets/objects/bin_sort/white_pent_cylinder.xml",
    "parallel_pick_white_hex_cylinder": "assets/objects/bin_sort/white_hex_cylinder.xml",
    "parallel_pick_white_oct_cylinder": "assets/objects/bin_sort/white_oct_cylinder.xml",
    "parallel_pick_white_dec_cylinder": "assets/objects/bin_sort/white_dec_cylinder.xml",
    "parallel_pick_white_cylinder": "assets/objects/bin_sort/white_cylinder.xml",
    "parallel_pick_black_bowl": "assets/objects/bin_sort/black_bowl.xml",
    "parallel_pick_black_tri_cylinder": "assets/objects/bin_sort/black_tri_cylinder.xml",
    "parallel_pick_black_box": "assets/objects/bin_sort/black_box.xml",
    "parallel_pick_black_pent_cylinder": "assets/objects/bin_sort/black_pent_cylinder.xml",
    "parallel_pick_black_hex_cylinder": "assets/objects/bin_sort/black_hex_cylinder.xml",
    "parallel_pick_black_oct_cylinder": "assets/objects/bin_sort/black_oct_cylinder.xml",
    "parallel_pick_black_dec_cylinder": "assets/objects/bin_sort/black_dec_cylinder.xml",
    "parallel_pick_black_cylinder": "assets/objects/bin_sort/black_cylinder.xml",
    "parallel_pick_blue_tri_cylinder": "assets/objects/bin_sort/blue_tri_cylinder.xml",
    "parallel_pick_blue_box": "assets/objects/bin_sort/blue_box.xml",
    "parallel_pick_blue_pent_cylinder": "assets/objects/bin_sort/blue_pent_cylinder.xml",
    "parallel_pick_blue_hex_cylinder": "assets/objects/bin_sort/blue_hex_cylinder.xml",
    "parallel_pick_blue_cylinder": "assets/objects/bin_sort/blue_cylinder.xml",
    "parallel_pick_green_tri_cylinder": "assets/objects/bin_sort/green_tri_cylinder.xml",
    "parallel_pick_green_box": "assets/objects/bin_sort/green_box.xml",
    "parallel_pick_green_pent_cylinder": "assets/objects/bin_sort/green_pent_cylinder.xml",
    "parallel_pick_green_hex_cylinder": "assets/objects/bin_sort/green_hex_cylinder.xml",
    "parallel_pick_green_cylinder": "assets/objects/bin_sort/green_cylinder.xml",
    # spring_door
    "spring_door_gray_microwave": "assets/objects/spring_door/gray_microwave.xml",
    "spring_door_mug": "assets/objects/spring_door/mug.xml",
    "spring_door_simple_mug": "assets/objects/spring_door/simple_mug.xml",
    # hinge chest
    "hinge_chest_chest": "assets/objects/hinge_chest/chest.xml",
    # transfer_gate (keys use the handover_hole_ prefix)
    "handover_hole_stand": "assets/objects/transfer_gate/stand.xml",
    "handover_hole_mat": "assets/objects/transfer_gate/mat.xml",
    "handover_hole_white_long_box": "assets/objects/transfer_gate/white_long_box.xml",
    # carry pot
    "carry_pot_pot": "assets/objects/carry_pot/pot.xml",
    "carry_pot_stove": "assets/objects/carry_pot/stove.xml",
    # join_blocks
    "h_block": "assets/objects/join_blocks/join_blocks_h_block.xml",
    "p_block": "assets/objects/join_blocks/join_blocks_p_block.xml",
    "wall": "assets/objects/join_blocks/join_blocks_wall.xml",
    # pour_marbles
    "marble": "assets/objects/pour_marbles/marble.xml",
    "teacup": "assets/objects/pour_marbles/teacup.xml",
    # ball_maze (selected by BallMazeTaskConfig.board_number via the key f"maze{n}")
    "maze1": "assets/objects/ball_maze/maze_simple_1.xml",
    "maze2": "assets/objects/ball_maze/maze_simple_2.xml",
    "maze3": "assets/objects/ball_maze/maze_simple_3.xml",
    "maze4": "assets/objects/ball_maze/maze_simple_4.xml",
    "maze5": "assets/objects/ball_maze/maze_simple_5.xml",
    "maze6": "assets/objects/ball_maze/maze_simple_6.xml",
    "maze7": "assets/objects/ball_maze/maze_simple_7.xml",
    "maze8": "assets/objects/ball_maze/maze_simple_8.xml",
    "maze9": "assets/objects/ball_maze/maze_simple_9.xml",
    "maze10": "assets/objects/ball_maze/maze_simple_10.xml",
}
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Scene registry: scene name -> scene XML path relative to DUOBENCH_PREFIX
# (turned into an absolute path further down in this module).
SCENE_PATHS = dict(
    empty_world="assets/scenes/base_world.xml",
    vention_world="assets/scenes/vention_table/vention_world.xml",
)
|
|
99
|
+
|
|
100
|
+
def _resolve_asset_paths(path_dicts: tuple[dict[str, str], ...], prefix: str | os.PathLike[str]) -> None:
    """Make every asset path in ``path_dicts`` absolute by joining it onto ``prefix``, in place.

    Args:
        path_dicts: Registries mapping an asset name to a path relative to ``prefix``.
        prefix: Root directory of the DuoBench assets.

    Raises:
        FileNotFoundError: If any referenced asset file does not exist. Every missing asset is
            listed (one per line) instead of only the first one, so a partial asset download
            can be diagnosed in a single run. With exactly one missing asset the message is
            "Asset <name> not found at path: <path>. Please make sure to download the assets."
    """
    missing: list[str] = []
    for path_dict in path_dicts:
        resolved: dict[str, str] = {}
        for name, rel_path in path_dict.items():
            abs_path = os.path.join(prefix, rel_path)
            if os.path.isfile(abs_path):
                resolved[name] = abs_path
            else:
                missing.append(f"Asset {name} not found at path: {abs_path}.")
        # Apply the rewrites after iterating rather than assigning into the dict mid-loop.
        path_dict.update(resolved)
    if missing:
        details = "\n".join(missing)
        raise FileNotFoundError(f"{details} Please make sure to download the assets.")


# Prepend the DUOBENCH prefix to all asset paths (in place) and check that the files exist.
# Done inside a helper so no loop variables leak into the package namespace.
_resolve_asset_paths((GRIPPER_PATHS, SCENE_PATHS, OBJECT_PATHS, CAMERA_PATHS), DUOBENCH_PREFIX)
|
|
109
|
+
|
|
110
|
+
# Merge the DuoBench asset registries (already resolved to absolute paths above) into the
# global RCS registries. Tasks are not registered here: task modules such as ball_maze
# register themselves when imported (rcs.TASKS.update / gym.register).
rcs.SCENE_PATHS.update(SCENE_PATHS)
rcs.OBJECT_PATHS.update(OBJECT_PATHS)
rcs.GRIPPER_PATHS.update(GRIPPER_PATHS)
rcs.CAMERA_PATHS.update(CAMERA_PATHS)


# Named transform "VENTION_HEIGHT_OFFSET": translation of 0.8095 along z with quaternion [0, 0, 0, 1].
# NOTE(review): presumably the Vention table-top height, with an identity rotation if
# rcs.common.Pose takes quaternions in (x, y, z, w) order — confirm both against rcs.
rcs.DEFAULT_TRANSFORMS.update(
    {"VENTION_HEIGHT_OFFSET": common.Pose(translation=np.array([0, 0, 0.8095]), quaternion=np.array([0, 0, 0, 1]))}
)
|
duobench/__main__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import pkgutil
|
|
2
|
+
from importlib import import_module
|
|
3
|
+
|
|
4
|
+
import duobench.tasks
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _register_duobench_tasks() -> None:
    """Import every module of ``duobench.tasks`` for its registration side effects.

    Task modules register themselves at import time (for example ``ball_maze`` calls
    ``rcs.TASKS.update`` and ``gym.register``), which is what makes ``duobench/*``
    environment ids resolvable.
    """
    qualified_prefix = duobench.tasks.__name__ + "."
    # With ``prefix`` set, ModuleInfo.name is already the fully qualified module name.
    for module_info in pkgutil.iter_modules(duobench.tasks.__path__, prefix=qualified_prefix):
        import_module(module_info.name)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def main() -> None:
    """Run the base RCS command-line app with the DuoBench tasks available.

    DuoBench task modules are imported first; ``rcs.__main__`` is imported lazily afterwards
    so the existing RCS commands can resolve ``duobench/*`` env ids.
    """
    _register_duobench_tasks()
    rcs_cli = import_module("rcs.__main__")
    rcs_cli.app()


if __name__ == "__main__":
    main()
|
|
duobench/tasks/__init__.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import gymnasium as gym
|
|
6
|
+
from rcs import sim as simulation
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(kw_only=True)
class TaskStage(ABC):
    """Progress tracker for a multi-stage task.

    Subclasses implement ``update_internal_state`` (read the simulator and refresh the boolean
    flags in ``internal_state``) and ``update_stage`` (map those flags to ``stage``).
    The task counts as successful once ``stage == max_stage``.
    """

    # Current stage index; 0 is the initial stage.
    stage: int = 0
    # Final stage; reaching it means success.
    max_stage: int = 0
    # Named boolean flags derived from the simulator state.
    internal_state: dict[str, bool] = field(default_factory=dict)
    # Optional human-readable sub-instruction per stage index (may be partial or empty).
    stage_to_subinstructions: dict[int, str] = field(default_factory=dict)
    # Overall task instruction.
    instruction: str = ""

    @abstractmethod
    def update_stage(self):
        """Recompute ``self.stage`` from ``self.internal_state``."""
        msg = "TaskStage is an abstract class. Please implement the update_stage method in a subclass."
        raise NotImplementedError(msg)

    @abstractmethod
    def update_internal_state(self, sim: simulation.Sim):
        """Refresh ``self.internal_state`` from the simulator ``sim``."""
        msg = "TaskStage is an abstract class. Please implement the update_internal_state method in a subclass."
        raise NotImplementedError(msg)

    def reset(self):
        """Return to stage 0 and set every internal flag to False.

        Note: every flag is cleared to False, regardless of the value it was initialized with.
        """
        self.stage = 0
        for key in self.internal_state:
            self.internal_state[key] = False

    @property
    def current_subinstruction(self) -> str | None:
        """Sub-instruction for the current stage, or None if no text is defined for it."""
        subinstructions = self.stage_to_subinstructions
        if subinstructions is None:
            return None
        # .get: the default mapping is empty and subclasses may not cover every stage.
        return subinstructions.get(self.stage)

    @property
    def success(self) -> bool:
        """Whether the final stage has been reached."""
        return self.stage == self.max_stage

    @property
    def normalized_stage(self) -> float:
        """Progress as ``stage / max_stage``.

        Used as the reward by TaskStageWrapper. When ``max_stage`` is 0 (the default), stage 0
        is already the final stage, so progress is 1.0 on success and 0.0 otherwise. This
        avoids a ZeroDivisionError.
        """
        if self.max_stage == 0:
            return 1.0 if self.success else 0.0
        return self.stage / self.max_stage

    @property
    def info(self) -> dict[str, Any]:
        """Stage information to merge into a gym ``info`` dict (stage keys stringified)."""
        return {
            "success": self.success,
            "stage": self.stage,
            "max_stage": self.max_stage,
            "current_subinstruction": self.current_subinstruction,
            "stage_to_subinstructions": {
                str(stage): instruction for stage, instruction in (self.stage_to_subinstructions or {}).items()
            },
            "instruction": self.instruction,
        }
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TaskStageWrapper(gym.Wrapper):
    """Gym wrapper whose reward and termination come from a TaskStage tracker.

    After each inner step the tracker is refreshed from the wrapped environment's ``sim``
    attribute (looked up with ``get_wrapper_attr``). The step's reward becomes the tracker's
    ``normalized_stage`` and ``terminated`` becomes its ``success`` flag; the inner env's own
    reward and terminated values are discarded. The tracker's ``info`` is merged into the info dict.
    """

    def __init__(
        self,
        env: gym.Env,
        stage_tracker: TaskStage,
    ):
        super().__init__(env)
        self.stage_tracker = stage_tracker

    def step(self, action: dict[str, Any]):
        observation, _inner_reward, _inner_terminated, truncated, info = super().step(action)
        tracker = self.stage_tracker
        tracker.update_internal_state(self.get_wrapper_attr("sim"))
        info.update(tracker.info)
        return observation, tracker.normalized_stage, tracker.success, truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        tracker = self.stage_tracker
        # Clear the tracker before the inner reset so the reported info starts at stage 0.
        tracker.reset()
        observation, info = super().reset(seed=seed, options=options)
        info.update(tracker.info)
        return observation, info
|
|
duobench/tasks/ball_maze.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import gymnasium as gym
|
|
5
|
+
import mujoco as mj
|
|
6
|
+
import numpy as np
|
|
7
|
+
import rcs
|
|
8
|
+
from rcs.envs.scenes import BaseTaskConfig, SimEnvCreatorConfig, Task
|
|
9
|
+
from rcs.sim.composer import ModelComposer
|
|
10
|
+
from rcs.sim.sim import Sim
|
|
11
|
+
|
|
12
|
+
from duobench.tasks import TaskStage, TaskStageWrapper
|
|
13
|
+
from duobench.utils.helper_wrappers import (
|
|
14
|
+
RandomSquareObjsPos,
|
|
15
|
+
get_bodies_in_contact_with_gripper_pad,
|
|
16
|
+
)
|
|
17
|
+
from duobench.utils.vention_config import VentionSceneFR3Duo
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(kw_only=True)
class BallMazeTaskConfig(BaseTaskConfig):
    """Configuration of the bimanual ball-maze task.

    Every setting is annotated so that it is a real dataclass field. An unannotated class
    attribute is not a field: it cannot be passed to the constructor, and if the base class
    declares a field of the same name, that field's default silently wins.
    """

    task_id: str = "ball_maze"

    # Selects the maze asset registered under the key f"maze{board_number}".
    board_number: int = 2
    # Name prefix for the maze's elements in the scene (object_prefix = object_body + "_").
    object_body: str = "board"
    # Passed to RandomSquareObjsPos as include_rotation.
    include_rotation: bool = True
    # Task instruction reported in the stage info.
    # NOTE(review): "roles" is likely meant to be "rolls"; kept verbatim because the string is
    # emitted at runtime and existing datasets may depend on it.
    task_instructions: str = "pick up the board and tilt it so the ball roles onto the red square"
    # Not read in this module — presumably consumed by rcs (TODO confirm).
    hard_reset: bool = False
    # Randomization parameters forwarded to RandomSquareObjsPos.
    x_width: float = 0.3
    y_width: float = 0.3
    z_init: float = 0.02
    obj_position_margin: float = 0.08
    object_joint_names: list[str] = field(default_factory=lambda: ["board_board_joint"])
    object2root_frame: rcs.common.Pose = field(
        default_factory=lambda: rcs.common.Pose(translation=np.array([0.5, 0, 0.05]), quaternion=np.array([0, 0, 0, 1]))
    )
    # Nominal maze center relative to the robot root frame (also the randomization center).
    object_center_to_root_frame: rcs.common.Pose = field(
        default_factory=lambda: rcs.common.Pose(
            # 0.02 = z_init
            translation=np.array([0.5, 0.0, 0.02]),
            quaternion=np.array([0, 0, 0, 1]),
        )
    )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class BallMazeTask(Task[BallMazeTaskConfig]):
    """Scene construction and environment wrapping for the ball-maze task."""

    @staticmethod
    def add_task_mujoco(cfg: BallMazeTaskConfig, composer: ModelComposer, env_cfg: SimEnvCreatorConfig):
        """Add task-specific elements to the Mujoco scene.

        Adds the maze selected by ``cfg.board_number`` at ``cfg.object_center_to_root_frame``
        (converted to the world frame with ``env_cfg.root_frame_to_world``), with its element
        names prefixed by ``cfg.object_body + "_"``.

        Raises:
            KeyError: If no maze asset is registered for ``cfg.board_number``.
        """
        object2world = env_cfg.root_frame_to_world * cfg.object_center_to_root_frame
        key = f"maze{cfg.board_number}"
        try:
            board_xml = rcs.OBJECT_PATHS[key]
        except KeyError as err:
            msg = f"No ball-maze asset registered under {key!r} (board_number={cfg.board_number!r})."
            raise KeyError(msg) from err
        # The registered value is a concrete file path (duobench registers absolute paths), so it
        # is deliberately not run through str.format, which would fail on paths containing braces.
        composer.add_object_world_frame(
            board_xml,
            object_prefix=cfg.object_body + "_",
            pose=object2world,
            register_root_relative_replay_free_joints=True,
        )

    @staticmethod
    def add_task_env(cfg: BallMazeTaskConfig, env: gym.Env, simulation: Sim, env_cfg: SimEnvCreatorConfig) -> gym.Env:
        """Add task-specific wrappers to the environment.

        Wraps ``env`` in RandomSquareObjsPos (randomized board placement around the maze center),
        then in BallMazeTaskWrapper, which handles stage tracking and the ball's soft reset.
        """
        _ = simulation  # part of the Task interface; not needed here
        object2world = env_cfg.root_frame_to_world * cfg.object_center_to_root_frame

        env = RandomSquareObjsPos(
            env=env,
            x_width=cfg.x_width,
            y_width=cfg.y_width,
            z_init=cfg.z_init,
            center2world=object2world,
            include_rotation=cfg.include_rotation,
            obj_joint_names=cfg.object_joint_names,
            obj_position_margin=cfg.obj_position_margin,
        )

        return BallMazeTaskWrapper(env, BallMazeStage(cfg))
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class BallMazeStage(TaskStage):
    """Stage tracker for the ball-maze task.

    Stages, as computed by ``update_stage``:
        0: neither gripper pad touches the maze;
        1: at least one gripper pad touches the maze;
        2: both pads touch the maze and it is no longer resting on the table;
        3: the ball has left the start marker at some point this episode;
        4: the ball has reached the goal marker at some point this episode (success).
    ``ball_left_start`` and ``ball_at_goal`` latch until ``reset``, so stages 3 and 4 are sticky.
    Stages 0-2 follow the current contact and lift state.
    """

    def __init__(self, cfg: BallMazeTaskConfig):
        super().__init__(
            stage=0,
            max_stage=4,
            internal_state={
                "left_arm_contact": False,
                "right_arm_contact": False,
                "both_arms_contact": False,
                "maze_lifted": False,
                "ball_left_start": False,
                "ball_at_goal": False,
                "ball_near_goal": False,
                "maze_on_table": True,
                "hands_released": False,
            },
            stage_to_subinstructions={
                0: "make contact with the maze",
                1: "grasp the maze with both arms",
                2: "lift the maze and move the ball out of the start area",
                3: "guide the ball from the start area toward the goal",
                4: "task completed; the ball has reached the goal",
            },
            instruction=cfg.task_instructions,
        )
        self.cfg = cfg
        # Names inside the maze model; all carry the "board_" object prefix used in add_task_mujoco.
        self.start_marker = "board_marker_start"
        self.goal_marker = "board_marker_goal"
        self.ball_joint = "board_ball_joint"
        self.ball_name = "board_ball"
        self.board_name = "board_board"
        # Planar distance from the start marker beyond which the ball has "left the start".
        self.start_dist = 0.02
        # Distance to the goal marker within which the ball is "at the goal".
        self.goal_dist = 0.015
        # Wider goal radius; only feeds the "ball_near_goal" flag, which update_stage does not use.
        self.done_goal_radius = 0.10
        # Board height recorded after reset; None until set_table_height is called.
        self.table_height: float | None = None

    @staticmethod
    def _planar_distance(a: np.ndarray, b: np.ndarray) -> float:
        """Euclidean distance between ``a`` and ``b`` using only their first two (x, y) components."""
        return float(np.linalg.norm(a[:2] - b[:2]))

    def set_table_height(self, table_height: float):
        """Record the reference board height that ``_maze_on_table`` compares against."""
        self.table_height = table_height

    def _board_body_ids(self, sim: Sim) -> set[int]:
        """Ids of all bodies whose name starts with "board_", excluding the ball body."""
        body_ids: set[int] = set()
        for body_id in range(sim.model.nbody):
            body_name = sim.model.body(body_id).name
            if body_name and body_name.startswith("board_") and body_name != self.ball_name:
                body_ids.add(body_id)
        return body_ids

    def _maze_on_table(self, sim: Sim, board_body_ids: set[int], board_height: float) -> bool:
        """Whether the maze is resting on the table.

        True if any board body is in contact with a geom named "floor". Otherwise, True if
        ``board_height`` is at most 0.02 above the recorded table height. False if no table
        height has been recorded yet.
        """
        for i in range(sim.data.ncon):
            contact = sim.data.contact[i]
            geom1_name = sim.model.geom(contact.geom1).name
            geom2_name = sim.model.geom(contact.geom2).name
            body1_id = sim.model.geom_bodyid[contact.geom1]
            body2_id = sim.model.geom_bodyid[contact.geom2]

            if geom1_name == "floor" and body2_id in board_body_ids:
                return True
            if geom2_name == "floor" and body1_id in board_body_ids:
                return True

        if self.table_height is None:
            return False

        return board_height <= self.table_height + 0.02

    def update_internal_state(self, sim: Sim):
        """Refresh all flags in ``internal_state`` from ``sim`` and recompute the stage."""
        board_body_ids = self._board_body_ids(sim)
        left_arm_contacts = get_bodies_in_contact_with_gripper_pad(sim, "left")
        right_arm_contacts = get_bodies_in_contact_with_gripper_pad(sim, "right")

        left_arm_contact = any(body_id in board_body_ids for body_id in left_arm_contacts)
        right_arm_contact = any(body_id in board_body_ids for body_id in right_arm_contacts)
        both_arms_contact = left_arm_contact and right_arm_contact
        any_arm_contact = left_arm_contact or right_arm_contact

        # Look the board up through the same named accessor used for the ball and markers below.
        # It raises for an unknown name, whereas mj_name2id returns -1 and xpos[-1] would
        # silently read the last body's height.
        board_height = float(sim.data.body(self.board_name).xpos[2])
        maze_on_table = self._maze_on_table(sim, board_body_ids, board_height)

        ball_pose = sim.data.body(self.ball_name).xpos
        start_pose = sim.data.body(self.start_marker).xpos
        goal_pose = sim.data.body(self.goal_marker).xpos
        # The ball center is always above the board plane, so using full 3D distance
        # can incorrectly mark "left start" immediately after reset.
        start_dist = self._planar_distance(start_pose, ball_pose)
        # NOTE(review): unlike start_dist, goal_dist is a full 3D distance, so any vertical offset
        # between the ball center and the goal marker counts toward it — confirm this is intended.
        goal_dist = np.linalg.norm(goal_pose - ball_pose)

        self.internal_state["left_arm_contact"] = left_arm_contact
        self.internal_state["right_arm_contact"] = right_arm_contact
        self.internal_state["both_arms_contact"] = both_arms_contact
        self.internal_state["maze_lifted"] = both_arms_contact and not maze_on_table
        # Latched: stays True once the ball has been far enough from the start marker.
        self.internal_state["ball_left_start"] = bool(
            self.internal_state["ball_left_start"] or start_dist >= self.start_dist
        )
        # Latched: stays True once the ball has reached the goal.
        self.internal_state["ball_at_goal"] = bool(self.internal_state["ball_at_goal"] or goal_dist <= self.goal_dist)
        self.internal_state["ball_near_goal"] = bool(goal_dist <= self.done_goal_radius)
        self.internal_state["maze_on_table"] = maze_on_table
        self.internal_state["hands_released"] = not any_arm_contact

        self.update_stage()

    def update_stage(self):
        """Map the internal flags to the highest satisfied stage (see the class docstring)."""
        if self.internal_state["ball_at_goal"]:
            self.stage = 4
        elif self.internal_state["ball_left_start"]:
            self.stage = 3
        elif self.internal_state["maze_lifted"]:
            self.stage = 2
        elif self.internal_state["left_arm_contact"] or self.internal_state["right_arm_contact"]:
            self.stage = 1
        else:
            self.stage = 0
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class BallMazeTaskWrapper(TaskStageWrapper):
    """
    Wrapper to conduct the task-specific soft reset for the ball maze while using the shared task-stage interface.

    On reset the wrapper does the following:
    - places the ball just above the start marker of the already-randomized board;
    - clears the ball's free-joint orientation and velocity;
    - steps the physics 10 times;
    - records the settled board height as the stage tracker's table reference;
    - refreshes the stage info.
    """

    def __init__(self, env: gym.Env, stage_tracker: BallMazeStage):
        super().__init__(env, stage_tracker)
        self.stage_tracker: BallMazeStage = stage_tracker
        self.sim = self.env.get_wrapper_attr("sim")
        self.start_marker = stage_tracker.start_marker
        self.ball_joint = stage_tracker.ball_joint

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        obs, info = super().reset(seed=seed, options=options)

        # The board pose is randomized by inner wrappers during reset, so place the
        # ball only after the full reset chain has finished. Spawn the ball tangent
        # to the maze floor and let the physics settle before the episode starts.
        pos = self.sim.data.body(self.start_marker).xpos.copy()
        # NOTE(review): hard-coded ball radius — confirm it matches the ball geom in the maze XML.
        ball_radius = 0.01
        spawn_clearance = 0.001
        pos[2] += ball_radius + spawn_clearance

        joint_id = self.sim.model.joint(self.ball_joint).id
        qpos_adr = self.sim.model.jnt_qposadr[joint_id]
        qvel_adr = self.sim.model.jnt_dofadr[joint_id]
        # MuJoCo free joint: qpos holds position (3) + unit quaternion w, x, y, z (4); qvel holds 6 DoF.
        self.sim.data.qpos[qpos_adr : qpos_adr + 3] = pos
        # Reset the orientation to identity too, so none of the ball's free-joint state
        # carries over from the previous rollout.
        self.sim.data.qpos[qpos_adr + 3 : qpos_adr + 7] = (1.0, 0.0, 0.0, 0.0)
        # Reset the free joint velocity so the ball does not inherit motion from the previous rollout.
        self.sim.data.qvel[qvel_adr : qvel_adr + 6] = 0

        mj.mj_forward(self.sim.model, self.sim.data)
        for _ in range(10):
            mj.mj_step(self.sim.model, self.sim.data)
        # The settled board height is the "resting on the table" reference for the stage tracker.
        self.stage_tracker.set_table_height(float(self.sim.data.body(self.stage_tracker.board_name).xpos[2]))
        self.stage_tracker.update_internal_state(self.sim)
        info.update(self.stage_tracker.info)
        # NOTE(review): `obs` comes from the inner reset, i.e. from before the ball was placed
        # and the physics was stepped above — confirm whether it should be recomputed.
        return obs, info
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# Register BallMazeTask in the global RCS task registry under "ball_maze"
# (the same id as BallMazeTaskConfig.task_id's default).
rcs.TASKS.update({"ball_maze": BallMazeTask})
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class BallMazeEnvConfig(VentionSceneFR3Duo):
    """Vention FR3 duo scene configured for the ball-maze task."""

    def config(self) -> SimEnvCreatorConfig:
        """Return the parent scene configuration with ``task_cfg`` set to a default BallMazeTaskConfig."""
        env_cfg = super().config()
        env_cfg.task_cfg = BallMazeTaskConfig()
        return env_cfg
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
# Expose the task under the Gymnasium id "duobench/ball_maze". The entry point is a
# BallMazeEnvConfig instance (NOTE(review): presumably callable as an env factory via
# VentionSceneFR3Duo — confirm).
gym.register(id="duobench/ball_maze", entry_point=BallMazeEnvConfig())
|