duobench 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
duobench/__init__.py ADDED
@@ -0,0 +1,119 @@
1
+ __version__ = "0.1.0"
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import rcs
8
+ from rcs import get_prefix
9
+ from rcs._core import common
10
+
11
# URL template of the GitHub archive for a tagged DuoBench release; ``{tag}`` is left unformatted here.
DUOBENCH_GITHUB_ASSET_ARCHIVE_URL = "https://github.com/RobotControlStack/duobench/archive/refs/tags/{tag}.zip"

# Directory that all relative asset paths in this module are joined onto.
# NOTE(review): the role of each positional argument is inferred from its value only —
# confirm against the documentation of ``rcs.get_prefix``.
DUOBENCH_PREFIX = get_prefix(
    os.getenv("DUOBENCH_PREFIX"),  # value of the DUOBENCH_PREFIX environment variable, or None
    Path(__file__).resolve().parents[2],  # two directories above the directory containing this file
    Path.home() / ".duobench",  # ~/.duobench
    __version__,
    DUOBENCH_GITHUB_ASSET_ARCHIVE_URL,
)
20
+
21
# Asset tables: registry key -> path relative to the DuoBench prefix.
# DuoBench declares no cameras or grippers of its own.
CAMERA_PATHS: dict[str, str] = {}
GRIPPER_PATHS: dict[str, str] = {}

# Shapes of the bin_sort objects. White and black objects exist in every shape,
# blue and green objects only in the reduced set.
_BIN_SORT_ALL_SHAPES = (
    "bowl",
    "tri_cylinder",
    "box",
    "pent_cylinder",
    "hex_cylinder",
    "oct_cylinder",
    "dec_cylinder",
    "cylinder",
)
_BIN_SORT_REDUCED_SHAPES = ("tri_cylinder", "box", "pent_cylinder", "hex_cylinder", "cylinder")
_BIN_SORT_SHAPES_BY_COLOR = (
    ("white", _BIN_SORT_ALL_SHAPES),
    ("black", _BIN_SORT_ALL_SHAPES),
    ("blue", _BIN_SORT_REDUCED_SHAPES),
    ("green", _BIN_SORT_REDUCED_SHAPES),
)

# NOTE: several key prefixes differ from the asset directory name
# (block_stacking -> block_balance, parallel_pick -> bin_sort, handover_hole -> transfer_gate),
# and two transfer_reorient keys are misspelled. Keys are looked up at runtime, so they are
# kept exactly as published.
OBJECT_PATHS: dict[str, str] = {
    # transfer_reorient
    "transfer_reorient_socket": "assets/objects/transfer_reorient/socket.xml",
    "transfer_reorient_hex_peg": "assets/objects/transfer_reorient/hexpeg.xml",
    "transfer_reorient_squere_peg": "assets/objects/transfer_reorient/squere_peg.xml",
    "transfer_reorient_triangel_peg": "assets/objects/transfer_reorient/triangle_peg.xml",
    # block_stacking
    "block_stacking_big_red_cube": "assets/objects/block_balance/big_red_cube.xml",
    "block_stacking_blue_rectangle": "assets/objects/block_balance/blue_rectangle.xml",
    "block_stacking_green_rectangle": "assets/objects/block_balance/green_rectangle.xml",
    "block_stacking_pink_beam": "assets/objects/block_balance/pink_beam.xml",
    # bin_sort: one entry per (color, shape) pair, colors in the order listed above
    **{
        f"parallel_pick_{color}_{shape}": f"assets/objects/bin_sort/{color}_{shape}.xml"
        for color, shapes in _BIN_SORT_SHAPES_BY_COLOR
        for shape in shapes
    },
    # spring_door
    "spring_door_gray_microwave": "assets/objects/spring_door/gray_microwave.xml",
    "spring_door_mug": "assets/objects/spring_door/mug.xml",
    "spring_door_simple_mug": "assets/objects/spring_door/simple_mug.xml",
    # hinge chest
    "hinge_chest_chest": "assets/objects/hinge_chest/chest.xml",
    # transfer_gate
    "handover_hole_stand": "assets/objects/transfer_gate/stand.xml",
    "handover_hole_mat": "assets/objects/transfer_gate/mat.xml",
    "handover_hole_white_long_box": "assets/objects/transfer_gate/white_long_box.xml",
    # carry pot
    "carry_pot_pot": "assets/objects/carry_pot/pot.xml",
    "carry_pot_stove": "assets/objects/carry_pot/stove.xml",
    # join_blocks
    "h_block": "assets/objects/join_blocks/join_blocks_h_block.xml",
    "p_block": "assets/objects/join_blocks/join_blocks_p_block.xml",
    "wall": "assets/objects/join_blocks/join_blocks_wall.xml",
    # pour_marbles
    "marble": "assets/objects/pour_marbles/marble.xml",
    "teacup": "assets/objects/pour_marbles/teacup.xml",
    # ball_maze: maze1 .. maze10
    **{f"maze{number}": f"assets/objects/ball_maze/maze_simple_{number}.xml" for number in range(1, 11)},
}


SCENE_PATHS: dict[str, str] = {
    "empty_world": "assets/scenes/base_world.xml",
    "vention_world": "assets/scenes/vention_table/vention_world.xml",
}
99
+
100
# Turn every relative asset path into a path under DUOBENCH_PREFIX (in place) and
# fail at import time as soon as one asset file is missing.
for path_dict in (GRIPPER_PATHS, SCENE_PATHS, OBJECT_PATHS, CAMERA_PATHS):
    for name, path in path_dict.items():
        abs_path = os.path.join(DUOBENCH_PREFIX, path)
        if not os.path.isfile(abs_path):
            error_msg = f"Asset {name} not found at path: {abs_path}. Please make sure to download the assets."
            raise FileNotFoundError(error_msg)
        # Only the value of an existing key is replaced, so the dict keeps its size while iterating.
        path_dict[name] = abs_path

# Publish the resolved DuoBench assets through the global RCS registries.
for _rcs_registry, _duobench_registry in (
    (rcs.SCENE_PATHS, SCENE_PATHS),
    (rcs.OBJECT_PATHS, OBJECT_PATHS),
    (rcs.GRIPPER_PATHS, GRIPPER_PATHS),
    (rcs.CAMERA_PATHS, CAMERA_PATHS),
):
    _rcs_registry.update(_duobench_registry)


# Named transform with a pure z translation of 0.8095.
# NOTE(review): the quaternion [0, 0, 0, 1] is presumably the identity in (x, y, z, w) order — confirm
# against ``rcs._core.common.Pose``.
rcs.DEFAULT_TRANSFORMS.update(
    {
        "VENTION_HEIGHT_OFFSET": common.Pose(
            translation=np.array([0, 0, 0.8095]),
            quaternion=np.array([0, 0, 0, 1]),
        )
    }
)
duobench/__main__.py ADDED
@@ -0,0 +1,20 @@
1
+ import pkgutil
2
+ from importlib import import_module
3
+
4
+ import duobench.tasks
5
+
6
+
7
def _register_duobench_tasks() -> None:
    """Import every module found directly inside the ``duobench.tasks`` package.

    Importing a task module executes its module-level code; only direct children
    of the package are visited (``pkgutil.iter_modules`` does not recurse).
    """
    tasks_package = duobench.tasks
    for module_info in pkgutil.iter_modules(tasks_package.__path__):
        import_module(f"{tasks_package.__name__}.{module_info.name}")
10
+
11
+
12
def main() -> None:
    """Run the RCS command line application with the DuoBench tasks loaded."""
    # The task modules have to be imported before the base RCS CLI so that the
    # existing commands can resolve `duobench/*` env ids.
    _register_duobench_tasks()
    rcs_cli = import_module("rcs.__main__")
    rcs_cli.app()


if __name__ == "__main__":
    main()
@@ -0,0 +1,81 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass, field
3
+ from typing import Any
4
+
5
+ import gymnasium as gym
6
+ from rcs import sim as simulation
7
+
8
+
9
+ @dataclass(kw_only=True)
10
+ class TaskStage(ABC):
11
+ stage: int = 0
12
+ max_stage: int = 0
13
+ internal_state: dict[str, bool] = field(default_factory=dict)
14
+ stage_to_subinstructions: dict[int, str] = field(default_factory=dict)
15
+ instruction: str = ""
16
+
17
+ @abstractmethod
18
+ def update_stage(self):
19
+ msg = "TaskStage is an abstract class. Please implement the update_stage method in a subclass."
20
+ raise NotImplementedError(msg)
21
+
22
+ @abstractmethod
23
+ def update_internal_state(self, sim: simulation.Sim):
24
+ msg = "TaskStage is an abstract class. Please implement the update_internal_state method in a subclass."
25
+ raise NotImplementedError(msg)
26
+
27
+ def reset(self):
28
+ self.stage = 0
29
+ for key in self.internal_state:
30
+ self.internal_state[key] = False
31
+
32
+ @property
33
+ def current_subinstruction(self) -> str:
34
+ return self.stage_to_subinstructions[self.stage] if self.stage_to_subinstructions is not None else None
35
+
36
+ @property
37
+ def success(self) -> bool:
38
+ return self.stage == self.max_stage
39
+
40
+ @property
41
+ def normalized_stage(self) -> float:
42
+ return self.stage / self.max_stage
43
+
44
+ @property
45
+ def info(self) -> dict[str, Any]:
46
+ return {
47
+ "success": self.success,
48
+ "stage": self.stage,
49
+ "max_stage": self.max_stage,
50
+ "current_subinstruction": self.current_subinstruction,
51
+ "stage_to_subinstructions": {
52
+ str(stage): instruction for stage, instruction in self.stage_to_subinstructions.items()
53
+ },
54
+ "instruction": self.instruction,
55
+ }
56
+
57
+
58
class TaskStageWrapper(gym.Wrapper):
    """Gym wrapper that derives reward, termination and info from a :class:`TaskStage`."""

    def __init__(
        self,
        env: gym.Env,
        stage_tracker: TaskStage,
    ):
        """Wrap ``env`` and keep ``stage_tracker`` as the source of task progress."""
        super().__init__(env)
        self.stage_tracker = stage_tracker

    def step(self, action: dict[str, Any]):
        """Step the wrapped env, then report the tracker's view of the task.

        The wrapped env's own reward and ``terminated`` flag are discarded: the
        reward is the tracker's ``normalized_stage`` and ``terminated`` is its
        ``success``. ``truncated`` is passed through, and the tracker's ``info``
        is merged into the returned info dict.
        """
        obs, _, _, truncated, info = super().step(action)

        tracker = self.stage_tracker
        tracker.update_internal_state(self.get_wrapper_attr("sim"))
        info.update(tracker.info)

        return obs, tracker.normalized_stage, tracker.success, truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        """Reset the tracker first, then the wrapped env, and merge the tracker's info."""
        tracker = self.stage_tracker
        tracker.reset()
        obs, info = super().reset(seed=seed, options=options)
        info.update(tracker.info)
        return obs, info
@@ -0,0 +1,251 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any
3
+
4
+ import gymnasium as gym
5
+ import mujoco as mj
6
+ import numpy as np
7
+ import rcs
8
+ from rcs.envs.scenes import BaseTaskConfig, SimEnvCreatorConfig, Task
9
+ from rcs.sim.composer import ModelComposer
10
+ from rcs.sim.sim import Sim
11
+
12
+ from duobench.tasks import TaskStage, TaskStageWrapper
13
+ from duobench.utils.helper_wrappers import (
14
+ RandomSquareObjsPos,
15
+ get_bodies_in_contact_with_gripper_pad,
16
+ )
17
+ from duobench.utils.vention_config import VentionSceneFR3Duo
18
+
19
+
20
@dataclass(kw_only=True)
class BallMazeTaskConfig(BaseTaskConfig):
    """Configuration of the ball-maze task.

    Every setting is annotated so that it is a real dataclass field: it can be
    passed to the constructor and shows up in ``dataclasses.fields``. Attributes
    without an annotation are plain class attributes that ``@dataclass`` ignores.
    """

    task_id: str = "ball_maze"

    # Selects the board asset registered as ``maze{board_number}`` in ``rcs.OBJECT_PATHS``.
    board_number: int = 2
    # Name prefix (an underscore is appended) given to the board object when it is added to the scene.
    object_body: str = "board"
    # Forwarded to ``RandomSquareObjsPos`` as ``include_rotation``.
    include_rotation: bool = True
    # Instruction text reported by the stage tracker.
    task_instructions: str = "pick up the board and tilt it so the ball roles onto the red square"
    hard_reset: bool = False
    # The following four values are forwarded unchanged to ``RandomSquareObjsPos``.
    x_width: float = 0.3
    y_width: float = 0.3
    z_init: float = 0.02
    obj_position_margin: float = 0.08
    # Joint names forwarded to ``RandomSquareObjsPos`` as ``obj_joint_names``.
    object_joint_names: list[str] = field(default_factory=lambda: ["board_board_joint"])
    object2root_frame: rcs.common.Pose = field(
        default_factory=lambda: rcs.common.Pose(translation=np.array([0.5, 0, 0.05]), quaternion=np.array([0, 0, 0, 1]))
    )
    # Pose that is composed with the env's ``root_frame_to_world`` to place the board
    # and to center the randomization.
    object_center_to_root_frame: rcs.common.Pose = field(
        default_factory=lambda: rcs.common.Pose(
            # 0.02 = z_init
            translation=np.array([0.5, 0.0, 0.02]),
            quaternion=np.array([0, 0, 0, 1]),
        )
    )
45
+
46
+
47
class BallMazeTask(Task[BallMazeTaskConfig]):
    """RCS task that adds a ball-maze board to the scene and wraps the env with stage tracking."""

    @staticmethod
    def add_task_mujoco(cfg: BallMazeTaskConfig, composer: ModelComposer, env_cfg: SimEnvCreatorConfig):
        """Add task-specific elements to the Mujoco scene.

        Raises:
            KeyError: if no asset is registered as ``maze{cfg.board_number}``.
        """
        object2world = env_cfg.root_frame_to_world * cfg.object_center_to_root_frame
        key = f"maze{cfg.board_number}"
        try:
            board_xml = rcs.OBJECT_PATHS[key]
        except KeyError:
            msg = f"No maze asset registered as '{key}' (board_number={cfg.board_number})."
            raise KeyError(msg) from None
        # Only the literal "{number}" placeholder is substituted. Running ``str.format`` on a
        # resolved filesystem path would raise (or mangle the path) whenever a directory name
        # contains braces.
        board_xml = board_xml.replace("{number}", str(cfg.board_number))

        composer.add_object_world_frame(
            board_xml,
            object_prefix=cfg.object_body + "_",
            pose=object2world,
            register_root_relative_replay_free_joints=True,
        )

    @staticmethod
    def add_task_env(cfg: BallMazeTaskConfig, env: gym.Env, simulation: Sim, env_cfg: SimEnvCreatorConfig) -> gym.Env:
        """Add task-specific wrappers to the environment.

        The env is wrapped with ``RandomSquareObjsPos``, configured from ``cfg``, and then
        with ``BallMazeTaskWrapper`` driven by a fresh ``BallMazeStage``.
        """
        _ = simulation  # part of the signature but not needed here
        object2world = env_cfg.root_frame_to_world * cfg.object_center_to_root_frame

        env = RandomSquareObjsPos(
            env=env,
            x_width=cfg.x_width,
            y_width=cfg.y_width,
            z_init=cfg.z_init,
            center2world=object2world,
            include_rotation=cfg.include_rotation,
            obj_joint_names=cfg.object_joint_names,
            obj_position_margin=cfg.obj_position_margin,
        )

        return BallMazeTaskWrapper(env, BallMazeStage(cfg))
81
+
82
+
83
class BallMazeStage(TaskStage):
    """Stage tracker for the ball-maze task (stages 0-4, see ``stage_to_subinstructions``)."""

    def __init__(self, cfg: BallMazeTaskConfig):
        """Set up flags, sub-instructions, body/joint names and distance thresholds."""
        initial_flags = {
            "left_arm_contact": False,
            "right_arm_contact": False,
            "both_arms_contact": False,
            "maze_lifted": False,
            "ball_left_start": False,
            "ball_at_goal": False,
            "ball_near_goal": False,
            "maze_on_table": True,
            "hands_released": False,
        }
        subinstructions = {
            0: "make contact with the maze",
            1: "grasp the maze with both arms",
            2: "lift the maze and move the ball out of the start area",
            3: "guide the ball from the start area toward the goal",
            4: "task completed; the ball has reached the goal",
        }
        super().__init__(
            stage=0,
            max_stage=4,
            internal_state=initial_flags,
            stage_to_subinstructions=subinstructions,
            instruction=cfg.task_instructions,
        )
        self.cfg = cfg
        # Names of the bodies and the joint that are looked up in the simulation.
        self.start_marker = "board_marker_start"
        self.goal_marker = "board_marker_goal"
        self.ball_joint = "board_ball_joint"
        self.ball_name = "board_ball"
        self.board_name = "board_board"
        # Thresholds compared against position differences from ``sim.data``.
        self.start_dist = 0.02
        self.goal_dist = 0.015
        self.done_goal_radius = 0.10
        # Reference board height; unknown until ``set_table_height`` is called.
        self.table_height: float | None = None

    @staticmethod
    def _planar_distance(a: np.ndarray, b: np.ndarray) -> float:
        """Distance between ``a`` and ``b`` using only their first two components."""
        offset = a[:2] - b[:2]
        return float(np.linalg.norm(offset))

    def set_table_height(self, table_height: float):
        """Store the board height that serves as the resting reference."""
        self.table_height = table_height

    def _board_body_ids(self, sim: Sim) -> set[int]:
        """Ids of all bodies named ``board_*`` except the ball body."""
        named_bodies = ((body_id, sim.model.body(body_id).name) for body_id in range(sim.model.nbody))
        return {
            body_id
            for body_id, body_name in named_bodies
            if body_name and body_name.startswith("board_") and body_name != self.ball_name
        }

    def _maze_on_table(self, sim: Sim, board_body_ids: set[int], board_height: float) -> bool:
        """Whether the maze is considered to rest on its support.

        True when any contact pairs the geom named "floor" with a board body.
        Otherwise the board height is compared with the stored reference height
        plus 0.02; without a stored reference the result is False.
        """
        for index in range(sim.data.ncon):
            contact = sim.data.contact[index]
            # Check the contact pair in both orders.
            for floor_geom, other_geom in ((contact.geom1, contact.geom2), (contact.geom2, contact.geom1)):
                touches_floor = sim.model.geom(floor_geom).name == "floor"
                if touches_floor and sim.model.geom_bodyid[other_geom] in board_body_ids:
                    return True

        return self.table_height is not None and board_height <= self.table_height + 0.02

    def update_internal_state(self, sim: Sim):
        """Recompute all flags from the simulation and then update the stage.

        ``ball_left_start`` and ``ball_at_goal`` latch: once True they stay True
        until ``reset``. All other flags reflect the current simulation state only.
        """
        board_body_id = mj.mj_name2id(sim.model, mj.mjtObj.mjOBJ_BODY, self.board_name)
        board_body_ids = self._board_body_ids(sim)
        left_touched = get_bodies_in_contact_with_gripper_pad(sim, "left")
        right_touched = get_bodies_in_contact_with_gripper_pad(sim, "right")

        left_arm_contact = any(body_id in board_body_ids for body_id in left_touched)
        right_arm_contact = any(body_id in board_body_ids for body_id in right_touched)
        both_arms_contact = left_arm_contact and right_arm_contact

        board_height = sim.data.xpos[board_body_id][2]
        maze_on_table = self._maze_on_table(sim, board_body_ids, board_height)

        ball_pos = sim.data.body(self.ball_name).xpos
        start_pos = sim.data.body(self.start_marker).xpos
        goal_pos = sim.data.body(self.goal_marker).xpos
        # The ball center is always above the board plane, so using full 3D distance
        # can incorrectly mark "left start" immediately after reset.
        start_dist = self._planar_distance(start_pos, ball_pos)
        # The goal check uses the full 3D distance.
        goal_dist = np.linalg.norm(goal_pos - ball_pos)

        flags = self.internal_state
        flags["left_arm_contact"] = left_arm_contact
        flags["right_arm_contact"] = right_arm_contact
        flags["both_arms_contact"] = both_arms_contact
        flags["maze_lifted"] = both_arms_contact and not maze_on_table
        flags["ball_left_start"] = bool(flags["ball_left_start"] or start_dist >= self.start_dist)
        flags["ball_at_goal"] = bool(flags["ball_at_goal"] or goal_dist <= self.goal_dist)
        flags["ball_near_goal"] = bool(goal_dist <= self.done_goal_radius)
        flags["maze_on_table"] = maze_on_table
        flags["hands_released"] = not (left_arm_contact or right_arm_contact)

        self.update_stage()

    def update_stage(self):
        """Set ``stage`` to the highest stage whose flag condition currently holds."""
        flags = self.internal_state
        if flags["ball_at_goal"]:
            self.stage = 4
            return
        if flags["ball_left_start"]:
            self.stage = 3
            return
        if flags["maze_lifted"]:
            self.stage = 2
            return
        if flags["left_arm_contact"] or flags["right_arm_contact"]:
            self.stage = 1
            return
        self.stage = 0
199
+
200
+
201
class BallMazeTaskWrapper(TaskStageWrapper):
    """
    Task-stage wrapper for the ball maze that additionally places the ball on the start
    marker and lets the physics settle on every reset.
    """

    def __init__(self, env: gym.Env, stage_tracker: BallMazeStage):
        """Wrap ``env`` and cache the simulation plus the names needed to place the ball."""
        super().__init__(env, stage_tracker)
        self.stage_tracker: BallMazeStage = stage_tracker
        self.sim = self.env.get_wrapper_attr("sim")
        self.start_marker = stage_tracker.start_marker
        self.ball_joint = stage_tracker.ball_joint

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        """Reset the env, spawn the ball above the start marker and refresh the tracker.

        NOTE(review): ``obs`` is returned exactly as produced by the inner reset, i.e. from
        before the ball is repositioned and the settle steps below are executed — confirm
        that this is intended.
        """
        obs, info = super().reset(seed=seed, options=options)
        sim = self.sim

        # The board pose is randomized by inner wrappers during reset, so place the
        # ball only after the full reset chain has finished. Spawn the ball tangent
        # to the maze floor and let the physics settle before the episode starts.
        ball_radius = 0.01
        spawn_clearance = 0.001
        spawn_pos = sim.data.body(self.start_marker).xpos.copy()
        spawn_pos[2] += ball_radius + spawn_clearance

        joint_id = sim.model.joint(self.ball_joint).id
        qpos_start = sim.model.jnt_qposadr[joint_id]
        dof_start = sim.model.jnt_dofadr[joint_id]
        # Only the three position entries of the joint are overwritten.
        sim.data.qpos[qpos_start : qpos_start + 3] = spawn_pos
        # Reset the free joint velocity so the ball does not inherit motion from the previous rollout.
        sim.data.qvel[dof_start : dof_start + 6] = 0

        mj.mj_forward(sim.model, sim.data)
        for _ in range(10):
            mj.mj_step(sim.model, sim.data)

        # The settled board height becomes the tracker's resting reference.
        tracker = self.stage_tracker
        settled_board_height = float(sim.data.body(tracker.board_name).xpos[2])
        tracker.set_table_height(settled_board_height)
        tracker.update_internal_state(sim)
        info.update(tracker.info)
        return obs, info
238
+
239
+
240
# Register the task class with RCS under the key "ball_maze".
rcs.TASKS.update({"ball_maze": BallMazeTask})


class BallMazeEnvConfig(VentionSceneFR3Duo):
    """``VentionSceneFR3Duo`` environment configuration with the ball-maze task attached."""

    def config(self) -> SimEnvCreatorConfig:
        """Return the parent configuration with ``task_cfg`` set to a default ``BallMazeTaskConfig``."""
        env_cfg = super().config()
        env_cfg.task_cfg = BallMazeTaskConfig()
        return env_cfg


# An instance of the config class is passed as the gym entry point.
gym.register(id="duobench/ball_maze", entry_point=BallMazeEnvConfig())