opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/envs/libero.py ADDED
@@ -0,0 +1,439 @@
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ r"""This module provides an environment wrapper for LIBERO tasks."""
17
+
18
+ from __future__ import annotations
19
+
20
+ import os
21
+ from collections import defaultdict
22
+ from collections.abc import Callable, Iterable, Mapping, Sequence
23
+ from functools import partial
24
+ from pathlib import Path
25
+ from typing import Any
26
+
27
+ import gymnasium as gym
28
+ import numpy as np
29
+ import torch
30
+ from gymnasium import spaces
31
+ from libero.libero import benchmark, get_libero_path
32
+ from libero.libero.envs import OffScreenRenderEnv
33
+ from robosuite.utils.transform_utils import quat2axisangle
34
+
35
+ from opentau.utils.accelerate_utils import acc_print
36
+
37
+
38
+ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
39
+ """Normalize camera_name into a non-empty list of strings."""
40
+ if isinstance(camera_name, str):
41
+ cams = [c.strip() for c in camera_name.split(",") if c.strip()]
42
+ elif isinstance(camera_name, (list, tuple)):
43
+ cams = [str(c).strip() for c in camera_name if str(c).strip()]
44
+ else:
45
+ raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
46
+ if not cams:
47
+ raise ValueError("camera_name resolved to an empty list.")
48
+ return cams
49
+
50
+
51
def _get_suite(name: str) -> benchmark.Benchmark:
    """Look up and instantiate the LIBERO benchmark suite called *name*.

    Raises:
        ValueError: when the name is not registered or the suite has no tasks.
    """
    registry = benchmark.get_benchmark_dict()
    try:
        suite_cls = registry[name]
    except KeyError:
        raise ValueError(
            f"Unknown LIBERO suite '{name}'. Available: {', '.join(sorted(registry.keys()))}"
        ) from None
    suite = suite_cls()
    # A suite with a missing/empty `tasks` attribute is unusable downstream.
    if not getattr(suite, "tasks", None):
        raise ValueError(f"Suite '{name}' has no tasks.")
    return suite
60
+
61
+
62
+ def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
63
+ """Validate/normalize task ids. If None → all tasks."""
64
+ if task_ids is None:
65
+ return list(range(total_tasks))
66
+ ids = sorted({int(t) for t in task_ids})
67
+ for t in ids:
68
+ if t < 0 or t >= total_tasks:
69
+ raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
70
+ return ids
71
+
72
+
73
def _get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
    """Load the pre-recorded initial simulator states for task *i* of *task_suite*."""
    task = task_suite.tasks[i]
    root = Path(get_libero_path("init_states"))
    # Trusted local benchmark asset; full (non-weights-only) unpickling is intentional.
    return torch.load(root / task.problem_folder / task.init_states_file, weights_only=False)  # nosec B614
81
+
82
+
83
def get_libero_dummy_action() -> list[float | int]:
    """Return the no-op action used to tick the simulation while the robot stays idle."""
    # Six zeroed motion components followed by -1 for the gripper channel.
    return [0] * 6 + [-1]
86
+
87
+
88
# Proprioceptive state size: EEF position (3) + axis-angle orientation (3) +
# gripper joint positions (2); see LiberoEnv._format_raw_obs.
OBS_STATE_DIM = 8
# Length of the action vector forwarded to the simulator (see get_libero_dummy_action).
ACTION_DIM = 7
# Loose bounds for the `agent_pos` Box observation space.
AGENT_POS_LOW = -1000.0
AGENT_POS_HIGH = 1000.0
# Bounds of the normalized action space.
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
# Per-suite episode step budgets; LiberoEnv.__init__ falls back to 500 for
# suites not listed here.
TASK_SUITE_MAX_STEPS: dict[str, int] = {
    "libero_spatial": 280,  # longest training demo has 193 steps
    "libero_object": 280,  # longest training demo has 254 steps
    "libero_goal": 300,  # longest training demo has 270 steps
    "libero_10": 520,  # longest training demo has 505 steps
    "libero_90": 400,  # longest training demo has 373 steps
}
101
+
102
+
103
class LiberoEnv(gym.Env):
    r"""Environment wrapper for LIBERO tasks.

    Wraps a single LIBERO `OffScreenRenderEnv` behind the gymnasium `Env`
    interface, exposing multi-camera pixel observations (and optionally the
    8-dim proprioceptive state) under LeRobot-style camera keys.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 80}

    def __init__(
        self,
        task_suite: Any,
        task_id: int,
        task_suite_name: str,
        camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
        obs_type: str = "pixels_agent_pos",
        render_mode: str = "rgb_array",
        observation_width: int = 256,
        observation_height: int = 256,
        visualization_width: int = 640,
        visualization_height: int = 480,
        init_states: bool = True,
        episode_index: int = 0,
        camera_name_mapping: dict[str, list[str]] | None = None,
        num_steps_wait: int = 10,
        render_cam: str | None = None,
    ):
        r"""Initialize the LiberoEnv.

        Args:
            task_suite: The LIBERO task suite to use.
            task_id: The ID of the task within the suite.
            task_suite_name: The name of the task suite.
            camera_name: The name(s) of the camera(s) to use for observations. If a string, can be comma-separated.
            obs_type: The type of observation to return. Options are 'pixels' or 'pixels_agent_pos'.
            render_mode: The render mode for the environment.
            observation_width: The width of the observation images.
            observation_height: The height of the observation images.
            visualization_width: The width of the visualization window.
            visualization_height: The height of the visualization window.
            init_states: Whether to use predefined initial states for the tasks.
            episode_index: The index of the episode for selecting initial states.
            camera_name_mapping: Optional mapping from raw camera names to desired observation keys.
            num_steps_wait: Number of no-op steps to take after reset to stabilize the environment.
            render_cam: The camera name to use for rendering. If None, uses the first camera.
        """
        super().__init__()
        self.task_id = task_id
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.observation_width = observation_width
        self.observation_height = observation_height
        self.visualization_width = visualization_width
        self.visualization_height = visualization_height
        self.init_states = init_states
        self.camera_name = _parse_camera_names(
            camera_name
        )  # agentview_image (main) or robot0_eye_in_hand_image (wrist)
        self.render_cam = render_cam

        # Map raw camera names to "image1" and "image2".
        # The preprocessing step `preprocess_observation` will then prefix these with `.images.*`,
        # following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`).
        # This ensures the policy consistently receives observations in the
        # expected format regardless of the original camera naming.
        if camera_name_mapping is None:
            camera_name_mapping = {
                "agentview_image": ["camera0"],
                "robot0_eye_in_hand_image": ["camera1"],
            }
        self.camera_name_mapping = camera_name_mapping
        for cam in self.camera_name_mapping:
            # A bare string would be iterated character-by-character downstream.
            assert not isinstance(self.camera_name_mapping[cam], str), (
                "camera_name_mapping values must be lists of strings; "
                f"got string {self.camera_name_mapping[cam]} for {cam} instead"
            )
        self.num_steps_wait = num_steps_wait
        self.episode_index = episode_index
        # Load once and keep
        self._init_states = _get_task_init_states(task_suite, self.task_id) if self.init_states else None
        self._init_state_id = self.episode_index  # tie each sub-env to a fixed init state

        self._env = self._make_envs_task(task_suite, self.task_id)
        default_steps = 500
        self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)

        # One uint8 HxWx3 Box per mapped camera key.
        images = {}
        for cam in self.camera_name:
            for mapped_cam in self.camera_name_mapping[cam]:
                images[mapped_cam] = spaces.Box(
                    low=0,
                    high=255,
                    shape=(self.observation_height, self.observation_width, 3),
                    dtype=np.uint8,
                )

        if self.obs_type == "state":
            raise NotImplementedError(
                "The 'state' observation type is not supported in LiberoEnv. "
                "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
            )

        elif self.obs_type == "pixels":
            self.observation_space = spaces.Dict(
                {
                    "pixels": spaces.Dict(images),
                }
            )
        elif self.obs_type == "pixels_agent_pos":
            self.observation_space = spaces.Dict(
                {
                    "pixels": spaces.Dict(images),
                    "agent_pos": spaces.Box(
                        low=AGENT_POS_LOW,
                        high=AGENT_POS_HIGH,
                        shape=(OBS_STATE_DIM,),
                        dtype=np.float64,
                    ),
                }
            )
        # NOTE(review): any other obs_type leaves observation_space unset here;
        # _format_raw_obs raises NotImplementedError for it at runtime.

        self.action_space = spaces.Box(
            low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
        )

    def render(self) -> np.ndarray:
        r"""Render the environment and return a numpy array representing the RGB camera.

        If `self.render_cam` is set, use that camera; otherwise, use the first camera.
        """
        # NOTE(review): reads the wrapped robosuite env's private `_get_observations`,
        # presumably to grab frames without advancing the simulation — confirm on upgrade.
        raw_obs = self._env.env._get_observations()
        cams: dict[str, np.ndarray] = self._format_raw_obs(raw_obs)["pixels"]
        # if `self.render_cam` is not set, use the first camera
        render_cam = self.render_cam or next(iter(cams))
        return cams[render_cam]

    def _make_envs_task(self, task_suite: Any, task_id: int = 0):
        """Build the off-screen LIBERO env for one task; caches task name/language on self."""
        task = task_suite.get_task(task_id)
        self.task = task.name
        self.task_description = task.language
        task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

        env_args = {
            "bddl_file_name": task_bddl_file,
            "camera_heights": self.observation_height,
            "camera_widths": self.observation_width,
        }
        env = OffScreenRenderEnv(**env_args)
        env.reset()
        return env

    def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]:
        """Convert a raw simulator observation dict into this env's observation format."""
        images = {}
        for camera_name in self.camera_name:
            image = raw_obs[camera_name]
            image = image[::-1, ::-1]  # rotate 180 degrees
            for mapped_cam in self.camera_name_mapping[camera_name]:
                images[mapped_cam] = image.copy()
        # 8-dim state: EEF position (3) + axis-angle orientation (3) + gripper qpos (2).
        state = np.concatenate(
            (
                raw_obs["robot0_eef_pos"],
                quat2axisangle(raw_obs["robot0_eef_quat"]),
                raw_obs["robot0_gripper_qpos"],
            )
        )
        agent_pos = state
        if self.obs_type == "pixels":
            return {"pixels": images.copy()}
        if self.obs_type == "pixels_agent_pos":
            return {
                "pixels": images.copy(),
                "agent_pos": agent_pos,
            }
        raise NotImplementedError(
            f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
            "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
        )

    def reset(self, seed=None, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
        r"""Reset the environment with the given seed.

        Args:
            seed: The seed to use for resetting the environment.
        Returns:
            observation: The initial observation after reset.
            info: Additional information about the reset.
        """
        super().reset(seed=seed)
        self._env.seed(seed)
        # NOTE(review): the init state is applied before self._env.reset(); verify the
        # wrapped env preserves it across reset rather than resampling.
        if self.init_states and self._init_states is not None:
            self._env.set_init_state(self._init_states[self._init_state_id])
        raw_obs = self._env.reset()

        # After reset, objects may be unstable (slightly floating, intersecting, etc.).
        # Step the simulator with a no-op action for a few frames so everything settles.
        # Increasing this value can improve determinism and reproducibility across resets.
        for _ in range(self.num_steps_wait):
            raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
        observation = self._format_raw_obs(raw_obs)
        info = {"is_success": False}
        return observation, info

    def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
        r"""Take a step in the environment with the given action.

        Args:
            action: The action to take.
        Returns:
            observation: The observation after taking the step.
            reward: The reward obtained from taking the step.
            terminated: Whether the episode has terminated.
            truncated: Whether the episode was truncated.
            info: Additional information about the step.
        """
        if action.ndim != 1:
            raise ValueError(
                f"Expected action to be 1-D (shape (action_dim,)), "
                f"but got shape {action.shape} with ndim={action.ndim}"
            )
        # Silently drop trailing (e.g. padded) action dimensions beyond ACTION_DIM.
        if len(action) > ACTION_DIM:
            action = action[:ACTION_DIM]
        raw_obs, reward, done, info = self._env.step(action)

        is_success = self._env.check_success()
        terminated = done or is_success
        info["is_success"] = is_success

        observation = self._format_raw_obs(raw_obs)
        # NOTE(review): the env re-initializes itself when the episode ends; the
        # returned observation is from before this internal reset.
        if done:
            self.reset()
        info.update(
            {
                "task": self.task,
                "task_id": self.task_id,
                "done": done,
                "is_success": is_success,
            }
        )
        # Truncation is never signalled here; step budgets live in _max_episode_steps.
        truncated = False
        return observation, reward, terminated, truncated, info

    def close(self):
        r"""Close the environment and release any resources."""
        self._env.close()
339
+
340
+
341
+ def _make_env_fns(
342
+ *,
343
+ suite,
344
+ suite_name: str,
345
+ task_id: int,
346
+ n_envs: int,
347
+ camera_names: list[str],
348
+ init_states: bool,
349
+ gym_kwargs: Mapping[str, Any],
350
+ ) -> list[Callable[[], LiberoEnv]]:
351
+ """Build n_envs factory callables for a single (suite, task_id)."""
352
+
353
+ def _make_env(episode_index: int, **kwargs) -> LiberoEnv:
354
+ local_kwargs = dict(kwargs)
355
+ return LiberoEnv(
356
+ task_suite=suite,
357
+ task_id=task_id,
358
+ task_suite_name=suite_name,
359
+ camera_name=camera_names,
360
+ init_states=init_states,
361
+ episode_index=episode_index,
362
+ **local_kwargs,
363
+ )
364
+
365
+ fns: list[Callable[[], LiberoEnv]] = []
366
+ for episode_index in range(n_envs):
367
+ fns.append(partial(_make_env, episode_index, **gym_kwargs))
368
+ return fns
369
+
370
+
371
# main API entry point
def create_libero_envs(
    task: str,
    n_envs: int,
    gym_kwargs: dict[str, Any] | None = None,
    camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
    init_states: bool = True,
    env_cls: type[gym.vector.SyncVectorEnv] | type[gym.vector.AsyncVectorEnv] | None = None,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
    """
    Create vectorized LIBERO environments with a consistent return shape.

    Args:
        task: A single suite name or a comma-separated list of suite names.
        n_envs: Number of rollouts *per task* (episode_index = 0..n_envs-1).
        gym_kwargs: Extra keyword arguments forwarded to each LiberoEnv. May contain
            `task_ids` (dict[str, list[int] | None]) to restrict tasks per suite;
            a suite absent from that dict runs all of its tasks.
        camera_name: Camera name(s) used for observations.
        init_states: Whether each env loads a fixed predefined initial state.
        env_cls: Callable wrapping a list of env factory callables
            (e.g. gym.vector.SyncVectorEnv or gym.vector.AsyncVectorEnv).

    Returns:
        dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)

    Raises:
        ValueError: on a non-callable env_cls, non-positive n_envs, an empty
            `task` string, an unknown suite, or a suite with no selected tasks.
    """
    if env_cls is None or not callable(env_cls):
        raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
    if not isinstance(n_envs, int) or n_envs <= 0:
        raise ValueError(f"n_envs must be a positive int; got {n_envs}.")

    gym_kwargs = dict(gym_kwargs or {})
    task_ids_filter: dict[str, list[int] | None] | None = gym_kwargs.pop("task_ids", None)

    camera_names = _parse_camera_names(camera_name)
    suite_names = [s.strip() for s in str(task).split(",") if s.strip()]
    if not suite_names:
        raise ValueError("`task` must contain at least one LIBERO suite name.")

    acc_print(
        f"Creating LIBERO envs | suites={suite_names} | n_envs(per task)={n_envs} | init_states={init_states}"
    )
    if task_ids_filter is not None:
        # No tasks selected → return empty dict.
        # This happens when you have more accelerator processes than evaluation tasks.
        if len(task_ids_filter) == 0:
            acc_print("Empty task_ids specified, returning empty dict.")
            return {}

        acc_print(f"Restricting to task_ids={task_ids_filter}")

    out: dict[str, dict[int, Any]] = defaultdict(dict)

    for suite_name in suite_names:
        suite = _get_suite(suite_name)
        total = len(suite.tasks)
        # Bug fix: use .get() so a filter dict that omits this suite selects all of
        # its tasks instead of raising KeyError (previously task_ids_filter[suite_name]).
        suite_filter = task_ids_filter.get(suite_name) if task_ids_filter else None
        selected = _select_task_ids(total, suite_filter)

        if not selected:
            raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")

        for tid in selected:
            fns = _make_env_fns(
                suite=suite,
                suite_name=suite_name,
                task_id=tid,
                n_envs=n_envs,
                camera_names=camera_names,
                init_states=init_states,
                gym_kwargs=gym_kwargs,
            )
            out[suite_name][tid] = env_cls(fns)
            acc_print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")

    # return plain dicts for predictability
    return {suite: dict(task_map) for suite, task_map in out.items()}
opentau/envs/utils.py ADDED
@@ -0,0 +1,204 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ r"""This module contains utility functions for environments."""
19
+
20
+ import warnings
21
+ from collections.abc import Mapping, Sequence
22
+ from functools import singledispatch
23
+ from typing import Any
24
+
25
+ import gymnasium as gym
26
+ import numpy as np
27
+ import torch
28
+ from torch import Tensor
29
+ from torchvision.transforms import Compose, Resize, ToTensor
30
+
31
+ from opentau.configs.train import TrainPipelineConfig
32
+ from opentau.datasets.lerobot_dataset import BaseDataset
33
+ from opentau.utils.accelerate_utils import get_proc_accelerator
34
+ from opentau.utils.utils import auto_torch_device
35
+
36
+
37
def preprocess_observation(np_observations: dict, cfg: TrainPipelineConfig) -> dict[str, Tensor]:
    # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
    """Convert environment observation to OpenTau format observation.

    Args:
        np_observations: Dictionary of observation batches from a Gym vector environment.
        cfg: Training configuration that contains max_state_dim, num_cams, resolution, etc.
    Returns:
        Dictionary of observation batches with keys renamed to OpenTau format and values as tensors.
    """
    # map to expected inputs for the policy
    out: dict[str, Tensor] = {}
    to_tensor_and_resize = Compose([ToTensor(), Resize(cfg.resolution, antialias=True)])

    if "pixels" in np_observations:
        assert isinstance(np_observations["pixels"], dict)
        pixel_batches: dict[str, np.ndarray] = np_observations["pixels"]
        # Each value is a batch of per-env frames; transform frame-by-frame, then re-batch.
        for cam_key, frame_batch in pixel_batches.items():
            out[cam_key] = torch.stack([to_tensor_and_resize(frame) for frame in frame_batch])

    if "environment_state" in np_observations:
        env_state = torch.from_numpy(np_observations["environment_state"]).float()
        if env_state.dim() == 1:
            env_state = env_state.unsqueeze(0)
        out["environment_state"] = env_state

    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
    agent_pos = torch.from_numpy(np_observations["agent_pos"]).float()
    if agent_pos.dim() == 1:
        agent_pos = agent_pos.unsqueeze(0)

    # Pad the state vector up to cfg.max_state_dim (see BaseDataset.pad_vector).
    out["state"] = BaseDataset.pad_vector(agent_pos, cfg.max_state_dim)

    batch_size = out["state"].shape[0]
    # add padding flags for cameras if needed
    if cfg.num_cams > 0:
        out["img_is_pad"] = torch.zeros((batch_size, cfg.num_cams), dtype=torch.bool)

    # convert all floating point tensors to bfloat16 to save memory
    acc = get_proc_accelerator()
    target_device = auto_torch_device() if acc is None else acc.device

    for key, value in out.items():
        if isinstance(value, Tensor):
            target_dtype = torch.bfloat16 if value.dtype.is_floating_point else value.dtype
            out[key] = value.to(device=target_device, dtype=target_dtype)

    return out
88
+
89
+
90
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
    r"""Check whether every sub-environment of a vectorized env shares one class.

    Args:
        env: A vectorized Gym environment (SyncVectorEnv or AsyncVectorEnv).
    Returns:
        True if all environments are of the same type, False otherwise.
    """
    if not isinstance(env, (gym.vector.SyncVectorEnv, gym.vector.AsyncVectorEnv)):
        raise ValueError("Only gym.vector.SyncVectorEnv and gym.vector.AsyncVectorEnv are supported for now.")

    env_classes = env.call("get_wrapper_attr", "__class__")
    reference = env_classes[0]
    return all(cls == reference for cls in env_classes)
104
+
105
+
106
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
    r"""Warn when sub-environments lack 'task_description'/'task' or differ in type.

    A warning is raised (at most once per message) if any sub-environment is
    missing both attributes, or if the sub-environments are of mixed types.

    Args:
        env: A vectorized Gym environment (SyncVectorEnv or AsyncVectorEnv).
    Raises:
        ValueError: If the environment is not a SyncVectorEnv or AsyncVectorEnv.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("once", UserWarning)  # Apply filter only in this function

        if not isinstance(env, (gym.vector.SyncVectorEnv, gym.vector.AsyncVectorEnv)):
            raise ValueError(
                "Only gym.vector.SyncVectorEnv and gym.vector.AsyncVectorEnv are supported for now."
            )

        has_desc = env.call("has_wrapper_attr", "task_description")
        has_task = env.call("has_wrapper_attr", "task")
        # Equivalent to `not all(td or t ...)`: at least one sub-env exposes neither.
        if any(not (desc or task) for desc, task in zip(has_desc, has_task, strict=True)):
            warnings.warn(
                "At least 1 environment does not have 'task_description' or 'task'. Some policies require these features.",
                UserWarning,
                stacklevel=2,
            )
        if not are_all_envs_same_type(env):
            warnings.warn(
                "The environments have different types. Make sure you infer the right task from each environment.",
                UserWarning,
                stacklevel=2,
            )
137
+
138
+
139
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
    r"""Fill observation['prompt'] with each sub-environment's task string.

    'task_description' takes precedence over 'task'; a sub-env exposing neither
    (or only empty/non-string values) gets an empty string.

    Args:
        env: A vectorized Gym environment (SyncVectorEnv or AsyncVectorEnv).
        observation: A dictionary of observations from the vectorized environment, which will be modified in place.
    Returns:
        The updated observation dictionary with the 'prompt' key added.
    """
    if not isinstance(env, (gym.vector.SyncVectorEnv, gym.vector.AsyncVectorEnv)):
        raise ValueError("Only gym.vector.SyncVectorEnv and gym.vector.AsyncVectorEnv are supported for now.")

    prompts = [""] * env.num_envs
    for attr in ("task_description", "task"):
        values = env.call("get_wrapper_attr", attr)
        if len(values) != env.num_envs:
            raise ValueError(f"Environment returned {len(values)} task(s); expected {env.num_envs}.")
        for idx, value in enumerate(values):
            # Keep only the first non-empty string found per sub-env.
            if prompts[idx] == "" and isinstance(value, str) and value != "":
                prompts[idx] = value

    observation["prompt"] = prompts
    return observation
162
+
163
+
164
+ def _close_single_env(env: Any) -> None:
165
+ try:
166
+ env.close()
167
+ except Exception as exc:
168
+ print(f"Exception while closing env {env}: {exc}")
169
+
170
+
171
@singledispatch
def close_envs(obj: Any) -> None:
    """Recursively close a single env, a sequence of envs, or a mapping of envs."""
    raise NotImplementedError(f"close_envs not implemented for type {type(obj).__name__}")


@close_envs.register(Mapping)
def _close_envs_mapping(env: Mapping) -> None:
    # Recurse into nested mappings; close any value exposing a `close` method.
    for value in env.values():
        if isinstance(value, Mapping):
            close_envs(value)
        elif hasattr(value, "close"):
            _close_single_env(value)


@close_envs.register(Sequence)
def _close_envs_sequence(envs: Sequence) -> None:
    # Plain strings/bytes are sequences too, but never hold envs.
    if isinstance(envs, (str, bytes)):
        return
    for item in envs:
        nested_container = isinstance(item, Mapping) or (
            isinstance(item, Sequence) and not isinstance(item, (str, bytes))
        )
        if nested_container:
            close_envs(item)
        elif hasattr(item, "close"):
            _close_single_env(item)


@close_envs.register(gym.Env)
def _close_envs_env(env: gym.Env) -> None:
    _close_single_env(env)


@close_envs.register(type(None))
def _close_envs_none(env: None) -> None:
    # Closing "no env" is a no-op by design.
    pass
@@ -0,0 +1,16 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .optimizers import OptimizerConfig as OptimizerConfig