opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/envs/libero.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
1
|
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
r"""This module provides an environment wrapper for LIBERO tasks."""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
from collections import defaultdict
|
|
22
|
+
from collections.abc import Callable, Iterable, Mapping, Sequence
|
|
23
|
+
from functools import partial
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Any
|
|
26
|
+
|
|
27
|
+
import gymnasium as gym
|
|
28
|
+
import numpy as np
|
|
29
|
+
import torch
|
|
30
|
+
from gymnasium import spaces
|
|
31
|
+
from libero.libero import benchmark, get_libero_path
|
|
32
|
+
from libero.libero.envs import OffScreenRenderEnv
|
|
33
|
+
from robosuite.utils.transform_utils import quat2axisangle
|
|
34
|
+
|
|
35
|
+
from opentau.utils.accelerate_utils import acc_print
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
|
39
|
+
"""Normalize camera_name into a non-empty list of strings."""
|
|
40
|
+
if isinstance(camera_name, str):
|
|
41
|
+
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
|
42
|
+
elif isinstance(camera_name, (list, tuple)):
|
|
43
|
+
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
|
44
|
+
else:
|
|
45
|
+
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
|
46
|
+
if not cams:
|
|
47
|
+
raise ValueError("camera_name resolved to an empty list.")
|
|
48
|
+
return cams
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _get_suite(name: str) -> benchmark.Benchmark:
|
|
52
|
+
"""Instantiate a LIBERO suite by name with clear validation."""
|
|
53
|
+
bench = benchmark.get_benchmark_dict()
|
|
54
|
+
if name not in bench:
|
|
55
|
+
raise ValueError(f"Unknown LIBERO suite '{name}'. Available: {', '.join(sorted(bench.keys()))}")
|
|
56
|
+
suite = bench[name]()
|
|
57
|
+
if not getattr(suite, "tasks", None):
|
|
58
|
+
raise ValueError(f"Suite '{name}' has no tasks.")
|
|
59
|
+
return suite
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
|
|
63
|
+
"""Validate/normalize task ids. If None → all tasks."""
|
|
64
|
+
if task_ids is None:
|
|
65
|
+
return list(range(total_tasks))
|
|
66
|
+
ids = sorted({int(t) for t in task_ids})
|
|
67
|
+
for t in ids:
|
|
68
|
+
if t < 0 or t >= total_tasks:
|
|
69
|
+
raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
|
|
70
|
+
return ids
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
|
|
74
|
+
init_states_path = (
|
|
75
|
+
Path(get_libero_path("init_states"))
|
|
76
|
+
/ task_suite.tasks[i].problem_folder
|
|
77
|
+
/ task_suite.tasks[i].init_states_file
|
|
78
|
+
)
|
|
79
|
+
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
|
|
80
|
+
return init_states
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_libero_dummy_action() -> list[float | int]:
|
|
84
|
+
"""Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
|
|
85
|
+
return [0, 0, 0, 0, 0, 0, -1]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
OBS_STATE_DIM = 8
|
|
89
|
+
ACTION_DIM = 7
|
|
90
|
+
AGENT_POS_LOW = -1000.0
|
|
91
|
+
AGENT_POS_HIGH = 1000.0
|
|
92
|
+
ACTION_LOW = -1.0
|
|
93
|
+
ACTION_HIGH = 1.0
|
|
94
|
+
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
|
95
|
+
"libero_spatial": 280, # longest training demo has 193 steps
|
|
96
|
+
"libero_object": 280, # longest training demo has 254 steps
|
|
97
|
+
"libero_goal": 300, # longest training demo has 270 steps
|
|
98
|
+
"libero_10": 520, # longest training demo has 505 steps
|
|
99
|
+
"libero_90": 400, # longest training demo has 373 steps
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class LiberoEnv(gym.Env):
|
|
104
|
+
r"""Environment wrapper for LIBERO tasks."""
|
|
105
|
+
|
|
106
|
+
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
|
|
107
|
+
|
|
108
|
+
def __init__(
|
|
109
|
+
self,
|
|
110
|
+
task_suite: Any,
|
|
111
|
+
task_id: int,
|
|
112
|
+
task_suite_name: str,
|
|
113
|
+
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
|
114
|
+
obs_type: str = "pixels_agent_pos",
|
|
115
|
+
render_mode: str = "rgb_array",
|
|
116
|
+
observation_width: int = 256,
|
|
117
|
+
observation_height: int = 256,
|
|
118
|
+
visualization_width: int = 640,
|
|
119
|
+
visualization_height: int = 480,
|
|
120
|
+
init_states: bool = True,
|
|
121
|
+
episode_index: int = 0,
|
|
122
|
+
camera_name_mapping: dict[str, list[str]] | None = None,
|
|
123
|
+
num_steps_wait: int = 10,
|
|
124
|
+
render_cam: str | None = None,
|
|
125
|
+
):
|
|
126
|
+
r"""Initialize the LiberoEnv.
|
|
127
|
+
Args:
|
|
128
|
+
task_suite: The LIBERO task suite to use.
|
|
129
|
+
task_id: The ID of the task within the suite.
|
|
130
|
+
task_suite_name: The name of the task suite.
|
|
131
|
+
camera_name: The name(s) of the camera(s) to use for observations. If a string, can be comma-separated.
|
|
132
|
+
obs_type: The type of observation to return. Options are 'pixels' or 'pixels
|
|
133
|
+
render_mode: The render mode for the environment.
|
|
134
|
+
observation_width: The width of the observation images.
|
|
135
|
+
observation_height: The height of the observation images.
|
|
136
|
+
visualization_width: The width of the visualization window.
|
|
137
|
+
visualization_height: The height of the visualization window.
|
|
138
|
+
init_states: Whether to use predefined initial states for the tasks.
|
|
139
|
+
episode_index: The index of the episode for selecting initial states.
|
|
140
|
+
camera_name_mapping: Optional mapping from raw camera names to desired observation keys.
|
|
141
|
+
num_steps_wait: Number of no-op steps to take after reset to stabilize the environment.
|
|
142
|
+
render_cam: The camera name to use for rendering. If None, uses the first camera.
|
|
143
|
+
"""
|
|
144
|
+
super().__init__()
|
|
145
|
+
self.task_id = task_id
|
|
146
|
+
self.obs_type = obs_type
|
|
147
|
+
self.render_mode = render_mode
|
|
148
|
+
self.observation_width = observation_width
|
|
149
|
+
self.observation_height = observation_height
|
|
150
|
+
self.visualization_width = visualization_width
|
|
151
|
+
self.visualization_height = visualization_height
|
|
152
|
+
self.init_states = init_states
|
|
153
|
+
self.camera_name = _parse_camera_names(
|
|
154
|
+
camera_name
|
|
155
|
+
) # agentview_image (main) or robot0_eye_in_hand_image (wrist)
|
|
156
|
+
self.render_cam = render_cam
|
|
157
|
+
|
|
158
|
+
# Map raw camera names to "image1" and "image2".
|
|
159
|
+
# The preprocessing step `preprocess_observation` will then prefix these with `.images.*`,
|
|
160
|
+
# following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`).
|
|
161
|
+
# This ensures the policy consistently receives observations in the
|
|
162
|
+
# expected format regardless of the original camera naming.
|
|
163
|
+
if camera_name_mapping is None:
|
|
164
|
+
camera_name_mapping = {
|
|
165
|
+
"agentview_image": ["camera0"],
|
|
166
|
+
"robot0_eye_in_hand_image": ["camera1"],
|
|
167
|
+
}
|
|
168
|
+
self.camera_name_mapping = camera_name_mapping
|
|
169
|
+
for cam in self.camera_name_mapping:
|
|
170
|
+
assert not isinstance(self.camera_name_mapping[cam], str), (
|
|
171
|
+
"camera_name_mapping values must be lists of strings; "
|
|
172
|
+
f"got string {self.camera_name_mapping[cam]} for {cam} instead"
|
|
173
|
+
)
|
|
174
|
+
self.num_steps_wait = num_steps_wait
|
|
175
|
+
self.episode_index = episode_index
|
|
176
|
+
# Load once and keep
|
|
177
|
+
self._init_states = _get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
|
178
|
+
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
|
179
|
+
|
|
180
|
+
self._env = self._make_envs_task(task_suite, self.task_id)
|
|
181
|
+
default_steps = 500
|
|
182
|
+
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
|
183
|
+
|
|
184
|
+
images = {}
|
|
185
|
+
for cam in self.camera_name:
|
|
186
|
+
for mapped_cam in self.camera_name_mapping[cam]:
|
|
187
|
+
images[mapped_cam] = spaces.Box(
|
|
188
|
+
low=0,
|
|
189
|
+
high=255,
|
|
190
|
+
shape=(self.observation_height, self.observation_width, 3),
|
|
191
|
+
dtype=np.uint8,
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
if self.obs_type == "state":
|
|
195
|
+
raise NotImplementedError(
|
|
196
|
+
"The 'state' observation type is not supported in LiberoEnv. "
|
|
197
|
+
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
elif self.obs_type == "pixels":
|
|
201
|
+
self.observation_space = spaces.Dict(
|
|
202
|
+
{
|
|
203
|
+
"pixels": spaces.Dict(images),
|
|
204
|
+
}
|
|
205
|
+
)
|
|
206
|
+
elif self.obs_type == "pixels_agent_pos":
|
|
207
|
+
self.observation_space = spaces.Dict(
|
|
208
|
+
{
|
|
209
|
+
"pixels": spaces.Dict(images),
|
|
210
|
+
"agent_pos": spaces.Box(
|
|
211
|
+
low=AGENT_POS_LOW,
|
|
212
|
+
high=AGENT_POS_HIGH,
|
|
213
|
+
shape=(OBS_STATE_DIM,),
|
|
214
|
+
dtype=np.float64,
|
|
215
|
+
),
|
|
216
|
+
}
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
self.action_space = spaces.Box(
|
|
220
|
+
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
def render(self) -> np.ndarray:
|
|
224
|
+
r"""Render the environment and return a numpy array representing the RGB camera.
|
|
225
|
+
If `self.render_cam` is set, use that camera; otherwise, use the first camera."""
|
|
226
|
+
raw_obs = self._env.env._get_observations()
|
|
227
|
+
cams: dict[str, np.ndarray] = self._format_raw_obs(raw_obs)["pixels"]
|
|
228
|
+
# if `self.render_cam` is not set, use the first camera
|
|
229
|
+
render_cam = self.render_cam or next(iter(cams))
|
|
230
|
+
return cams[render_cam]
|
|
231
|
+
|
|
232
|
+
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
|
|
233
|
+
task = task_suite.get_task(task_id)
|
|
234
|
+
self.task = task.name
|
|
235
|
+
self.task_description = task.language
|
|
236
|
+
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
|
|
237
|
+
|
|
238
|
+
env_args = {
|
|
239
|
+
"bddl_file_name": task_bddl_file,
|
|
240
|
+
"camera_heights": self.observation_height,
|
|
241
|
+
"camera_widths": self.observation_width,
|
|
242
|
+
}
|
|
243
|
+
env = OffScreenRenderEnv(**env_args)
|
|
244
|
+
env.reset()
|
|
245
|
+
return env
|
|
246
|
+
|
|
247
|
+
def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]:
|
|
248
|
+
images = {}
|
|
249
|
+
for camera_name in self.camera_name:
|
|
250
|
+
image = raw_obs[camera_name]
|
|
251
|
+
image = image[::-1, ::-1] # rotate 180 degrees
|
|
252
|
+
for mapped_cam in self.camera_name_mapping[camera_name]:
|
|
253
|
+
images[mapped_cam] = image.copy()
|
|
254
|
+
state = np.concatenate(
|
|
255
|
+
(
|
|
256
|
+
raw_obs["robot0_eef_pos"],
|
|
257
|
+
quat2axisangle(raw_obs["robot0_eef_quat"]),
|
|
258
|
+
raw_obs["robot0_gripper_qpos"],
|
|
259
|
+
)
|
|
260
|
+
)
|
|
261
|
+
agent_pos = state
|
|
262
|
+
if self.obs_type == "pixels":
|
|
263
|
+
return {"pixels": images.copy()}
|
|
264
|
+
if self.obs_type == "pixels_agent_pos":
|
|
265
|
+
return {
|
|
266
|
+
"pixels": images.copy(),
|
|
267
|
+
"agent_pos": agent_pos,
|
|
268
|
+
}
|
|
269
|
+
raise NotImplementedError(
|
|
270
|
+
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
|
|
271
|
+
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
def reset(self, seed=None, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
275
|
+
r"""Reset the environment with the given seed.
|
|
276
|
+
|
|
277
|
+
Args:
|
|
278
|
+
seed: The seed to use for resetting the environment.
|
|
279
|
+
Returns:
|
|
280
|
+
observation: The initial observation after reset.
|
|
281
|
+
info: Additional information about the reset.
|
|
282
|
+
"""
|
|
283
|
+
super().reset(seed=seed)
|
|
284
|
+
self._env.seed(seed)
|
|
285
|
+
if self.init_states and self._init_states is not None:
|
|
286
|
+
self._env.set_init_state(self._init_states[self._init_state_id])
|
|
287
|
+
raw_obs = self._env.reset()
|
|
288
|
+
|
|
289
|
+
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
|
290
|
+
# Step the simulator with a no-op action for a few frames so everything settles.
|
|
291
|
+
# Increasing this value can improve determinism and reproducibility across resets.
|
|
292
|
+
for _ in range(self.num_steps_wait):
|
|
293
|
+
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
|
294
|
+
observation = self._format_raw_obs(raw_obs)
|
|
295
|
+
info = {"is_success": False}
|
|
296
|
+
return observation, info
|
|
297
|
+
|
|
298
|
+
def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
|
299
|
+
r"""Take a step in the environment with the given action.
|
|
300
|
+
Args:
|
|
301
|
+
action: The action to take.
|
|
302
|
+
Returns:
|
|
303
|
+
observation: The observation after taking the step.
|
|
304
|
+
reward: The reward obtained from taking the step.
|
|
305
|
+
terminated: Whether the episode has terminated.
|
|
306
|
+
truncated: Whether the episode was truncated.
|
|
307
|
+
info: Additional information about the step.
|
|
308
|
+
"""
|
|
309
|
+
if action.ndim != 1:
|
|
310
|
+
raise ValueError(
|
|
311
|
+
f"Expected action to be 1-D (shape (action_dim,)), "
|
|
312
|
+
f"but got shape {action.shape} with ndim={action.ndim}"
|
|
313
|
+
)
|
|
314
|
+
if len(action) > ACTION_DIM:
|
|
315
|
+
action = action[:ACTION_DIM]
|
|
316
|
+
raw_obs, reward, done, info = self._env.step(action)
|
|
317
|
+
|
|
318
|
+
is_success = self._env.check_success()
|
|
319
|
+
terminated = done or is_success
|
|
320
|
+
info["is_success"] = is_success
|
|
321
|
+
|
|
322
|
+
observation = self._format_raw_obs(raw_obs)
|
|
323
|
+
if done:
|
|
324
|
+
self.reset()
|
|
325
|
+
info.update(
|
|
326
|
+
{
|
|
327
|
+
"task": self.task,
|
|
328
|
+
"task_id": self.task_id,
|
|
329
|
+
"done": done,
|
|
330
|
+
"is_success": is_success,
|
|
331
|
+
}
|
|
332
|
+
)
|
|
333
|
+
truncated = False
|
|
334
|
+
return observation, reward, terminated, truncated, info
|
|
335
|
+
|
|
336
|
+
def close(self):
|
|
337
|
+
r"""Close the environment and release any resources."""
|
|
338
|
+
self._env.close()
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def _make_env_fns(
    *,
    suite,
    suite_name: str,
    task_id: int,
    n_envs: int,
    camera_names: list[str],
    init_states: bool,
    gym_kwargs: Mapping[str, Any],
) -> list[Callable[[], LiberoEnv]]:
    """Return ``n_envs`` zero-argument factories for one (suite, task_id).

    Each factory constructs a ``LiberoEnv`` bound to a distinct
    ``episode_index`` (0 .. n_envs - 1), so every sub-env replays its own
    fixed initial state.
    """

    def _build(episode_index: int, **kwargs) -> LiberoEnv:
        extra = dict(kwargs)
        return LiberoEnv(
            task_suite=suite,
            task_id=task_id,
            task_suite_name=suite_name,
            camera_name=camera_names,
            init_states=init_states,
            episode_index=episode_index,
            **extra,
        )

    return [partial(_build, episode_index, **gym_kwargs) for episode_index in range(n_envs)]
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
# main API entry point
|
|
372
|
+
def create_libero_envs(
    task: str,
    n_envs: int,
    gym_kwargs: dict[str, Any] | None = None,
    camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
    init_states: bool = True,
    env_cls: type[gym.vector.SyncVectorEnv] | type[gym.vector.AsyncVectorEnv] | None = None,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
    """
    Create vectorized LIBERO environments with a consistent return shape.

    Returns:
        dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
    Notes:
        - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
        - `task` can be a single suite or a comma-separated list of suites.
        - You may pass `task_ids` (dict[str, list[int] | None]) inside `gym_kwargs` to restrict tasks per suite.
    Raises:
        ValueError: On a missing/non-callable env_cls, a non-positive n_envs,
            an empty `task` string, or a suite with no selected tasks.
    """
    if env_cls is None or not callable(env_cls):
        raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
    if not isinstance(n_envs, int) or n_envs <= 0:
        raise ValueError(f"n_envs must be a positive int; got {n_envs}.")

    # Copy so popping "task_ids" does not mutate the caller's dict.
    gym_kwargs = dict(gym_kwargs or {})
    task_ids_filter: dict[str, list[int] | None] | None = gym_kwargs.pop("task_ids", None)

    camera_names = _parse_camera_names(camera_name)
    suite_names = [s.strip() for s in str(task).split(",") if s.strip()]
    if not suite_names:
        raise ValueError("`task` must contain at least one LIBERO suite name.")

    acc_print(
        f"Creating LIBERO envs | suites={suite_names} | n_envs(per task)={n_envs} | init_states={init_states}"
    )
    if task_ids_filter is not None:
        # No tasks selected → return empty dict.
        # This happens when you have more accelerator processes than evaluation tasks.
        if len(task_ids_filter) == 0:
            acc_print("Empty task_ids specified, returning empty dict.")
            return {}

        acc_print(f"Restricting to task_ids={task_ids_filter}")

    out: dict[str, dict[int, Any]] = defaultdict(dict)

    for suite_name in suite_names:
        suite = _get_suite(suite_name)
        total = len(suite.tasks)
        # NOTE(review): `task_ids_filter[suite_name]` raises KeyError when a
        # non-empty filter omits a suite listed in `task` — confirm intended.
        selected = _select_task_ids(total, task_ids_filter and task_ids_filter[suite_name])

        if not selected:
            raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")

        for tid in selected:
            fns = _make_env_fns(
                suite=suite,
                suite_name=suite_name,
                task_id=tid,
                n_envs=n_envs,
                camera_names=camera_names,
                init_states=init_states,
                gym_kwargs=gym_kwargs,
            )
            out[suite_name][tid] = env_cls(fns)
            acc_print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")

    # return plain dicts for predictability
    return {suite: dict(task_map) for suite, task_map in out.items()}
|
opentau/envs/utils.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
4
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
r"""This module contains utility functions for environments."""
|
|
19
|
+
|
|
20
|
+
import warnings
|
|
21
|
+
from collections.abc import Mapping, Sequence
|
|
22
|
+
from functools import singledispatch
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
import gymnasium as gym
|
|
26
|
+
import numpy as np
|
|
27
|
+
import torch
|
|
28
|
+
from torch import Tensor
|
|
29
|
+
from torchvision.transforms import Compose, Resize, ToTensor
|
|
30
|
+
|
|
31
|
+
from opentau.configs.train import TrainPipelineConfig
|
|
32
|
+
from opentau.datasets.lerobot_dataset import BaseDataset
|
|
33
|
+
from opentau.utils.accelerate_utils import get_proc_accelerator
|
|
34
|
+
from opentau.utils.utils import auto_torch_device
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def preprocess_observation(np_observations: dict, cfg: TrainPipelineConfig) -> dict[str, Tensor]:
    # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
    """Convert environment observation to OpenTau format observation.

    Args:
        np_observations: Dictionary of observation batches from a Gym vector environment.
            Must contain an "agent_pos" entry (indexed unconditionally below);
            "pixels" and "environment_state" are optional.
        cfg: Training configuration that contains max_state_dim, num_cams, resolution, etc.
    Returns:
        Dictionary of observation batches with keys renamed to OpenTau format and values as tensors,
        moved to the active device with floating-point values cast to bfloat16.
    """
    # map to expected inputs for the policy
    return_observations = {}
    img_transform = Compose([ToTensor(), Resize(cfg.resolution, antialias=True)])

    if "pixels" in np_observations:
        assert isinstance(np_observations["pixels"], dict)
        imgs: dict[str, np.ndarray] = np_observations["pixels"]

        # NOTE: the inner `for img in img` shadows the batch array on purpose —
        # each value is iterated frame-by-frame (presumably a batch of HWC
        # images; ToTensor converts each frame to CHW — TODO confirm layout).
        for imgkey, img in imgs.items():
            return_observations[imgkey] = torch.stack([img_transform(img) for img in img])

    if "environment_state" in np_observations:
        env_state = torch.from_numpy(np_observations["environment_state"]).float()
        # Promote a single unbatched vector to a batch of one.
        if env_state.dim() == 1:
            env_state = env_state.unsqueeze(0)

        return_observations["environment_state"] = env_state

    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
    # NOTE(review): this raises KeyError when "agent_pos" is absent (e.g. a
    # pixels-only env) — confirm all callers provide it.
    agent_pos = torch.from_numpy(np_observations["agent_pos"]).float()
    if agent_pos.dim() == 1:
        agent_pos = agent_pos.unsqueeze(0)

    # Preprocess so that agent_pos has the same dimension as max_state_dim
    agent_pos = BaseDataset.pad_vector(agent_pos, cfg.max_state_dim)
    return_observations["state"] = agent_pos

    batch_size = agent_pos.shape[0]
    # add padding flags for cameras if needed
    if cfg.num_cams > 0:
        return_observations["img_is_pad"] = torch.zeros((batch_size, cfg.num_cams), dtype=torch.bool)

    # convert all floating point tensors to bfloat16 to save memory
    acc = get_proc_accelerator()
    device = auto_torch_device() if acc is None else acc.device

    for k, v in return_observations.items():
        if isinstance(v, Tensor):
            dtype = torch.bfloat16 if v.dtype.is_floating_point else v.dtype
            return_observations[k] = v.to(device=device, dtype=dtype)

    return return_observations
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
    r"""Report whether every sub-environment of *env* shares one class.

    Args:
        env: A vectorized Gym environment (SyncVectorEnv or AsyncVectorEnv).
    Returns:
        True if all sub-environments are of the same type, False otherwise.
    Raises:
        ValueError: If *env* is not a supported vector-env class.
    """
    if not isinstance(env, (gym.vector.SyncVectorEnv, gym.vector.AsyncVectorEnv)):
        raise ValueError("Only gym.vector.SyncVectorEnv and gym.vector.AsyncVectorEnv are supported for now.")

    sub_env_types = list(env.call("get_wrapper_attr", "__class__"))
    reference = sub_env_types[0]
    return all(candidate == reference for candidate in sub_env_types)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
    r"""Warn when sub-environments lack 'task_description'/'task' attributes
    or are of heterogeneous types.

    Args:
        env: A vectorized Gym environment (SyncVectorEnv or AsyncVectorEnv).
    Raises:
        ValueError: If the environment is not a SyncVectorEnv or AsyncVectorEnv.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("once", UserWarning)  # Apply filter only in this function

        if not isinstance(env, (gym.vector.SyncVectorEnv, gym.vector.AsyncVectorEnv)):
            raise ValueError(
                "Only gym.vector.SyncVectorEnv and gym.vector.AsyncVectorEnv are supported for now."
            )

        has_desc = env.call("has_wrapper_attr", "task_description")
        has_task = env.call("has_wrapper_attr", "task")
        every_env_has_one = all(desc or task for desc, task in zip(has_desc, has_task, strict=True))
        if not every_env_has_one:
            warnings.warn(
                "At least 1 environment does not have 'task_description' or 'task'. Some policies require these features.",
                UserWarning,
                stacklevel=2,
            )
        if not are_all_envs_same_type(env):
            warnings.warn(
                "The environments have different types. Make sure you infer the right task from each environment.",
                UserWarning,
                stacklevel=2,
            )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
    r"""Attach per-sub-env task strings to *observation* under the 'prompt' key.

    'task_description' is consulted before 'task'; for each sub-environment
    the first non-empty string wins.

    Args:
        env: A vectorized Gym environment (SyncVectorEnv or AsyncVectorEnv).
        observation: A dictionary of observations from the vectorized environment, which will be modified in place.
    Returns:
        The updated observation dictionary with the 'prompt' key added.
    Raises:
        ValueError: On an unsupported vector-env class, or when the env reports
            a number of tasks different from ``env.num_envs``.
    """
    if not isinstance(env, (gym.vector.SyncVectorEnv, gym.vector.AsyncVectorEnv)):
        raise ValueError("Only gym.vector.SyncVectorEnv and gym.vector.AsyncVectorEnv are supported for now.")

    prompts = [""] * env.num_envs
    for attr in ("task_description", "task"):
        values = env.call("get_wrapper_attr", attr)
        if len(values) != env.num_envs:
            raise ValueError(f"Environment returned {len(values)} task(s); expected {env.num_envs}.")
        for idx, value in enumerate(values):
            if prompts[idx] != "":
                continue
            if isinstance(value, str) and value != "":
                prompts[idx] = value

    observation["prompt"] = prompts
    return observation
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _close_single_env(env: Any) -> None:
    """Best-effort close of one environment; failures are printed, not raised."""
    try:
        env.close()
    except Exception as exc:
        message = f"Exception while closing env {env}: {exc}"
        print(message)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@singledispatch
def close_envs(obj: Any) -> None:
    """Close a single environment, a list of environments, or a dictionary of environments."""
    raise NotImplementedError(f"close_envs not implemented for type {type(obj).__name__}")


@close_envs.register
def _(env: Mapping) -> None:
    # Recurse into nested mappings; close leaf values exposing `.close()`.
    for value in env.values():
        if isinstance(value, Mapping):
            close_envs(value)
        elif hasattr(value, "close"):
            _close_single_env(value)


@close_envs.register
def _(envs: Sequence) -> None:
    # Strings/bytes are Sequences too, but never hold environments.
    if isinstance(envs, (str, bytes)):
        return
    for item in envs:
        is_container = isinstance(item, Mapping) or (
            isinstance(item, Sequence) and not isinstance(item, (str, bytes))
        )
        if is_container:
            close_envs(item)
        elif hasattr(item, "close"):
            _close_single_env(item)


@close_envs.register
def _(env: gym.Env) -> None:
    # A bare gym environment: close it directly.
    _close_single_env(env)


@close_envs.register
def _(env: None) -> None:
    # Nothing to close when no environment was created.
    pass
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from .optimizers import OptimizerConfig as OptimizerConfig
|