opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,137 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Utilities for file I/O operations.
18
+
19
+ This module provides functions for reading and writing JSON files, saving videos,
20
+ and deserializing JSON data into structured objects with type checking.
21
+ """
22
+
23
+ import json
24
+ import warnings
25
+ from pathlib import Path
26
+ from typing import TypeVar
27
+
28
+ import imageio
29
+
30
+ JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]  # recursive alias: any JSON-serializable value
31
+ T = TypeVar("T", bound=JsonLike)  # template type used by deserialize_json_into_object
32
+
33
+
34
def write_video(video_path: str | Path, stacked_frames: list, fps: float) -> None:
    """Encode a sequence of frames into a video file on disk.

    Args:
        video_path: Destination path of the video file.
        stacked_frames: Ordered image frames to encode.
        fps: Playback frame rate of the resulting video.
    """
    with warnings.catch_warnings():
        # imageio's plugin machinery can import pkg_resources, which emits a
        # DeprecationWarning; silence only that specific message while encoding.
        warnings.filterwarnings(
            "ignore",
            "pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
48
+
49
+
50
+ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
51
+ """Load JSON data and recursively fill an object with matching structure.
52
+
53
+ Loads the JSON data from fpath and recursively fills obj with the
54
+ corresponding values (strictly matching structure and types).
55
+ Tuples in obj are expected to be lists in the JSON data, which will be
56
+ converted back into tuples.
57
+
58
+ Args:
59
+ fpath: Path to the JSON file to load.
60
+ obj: Template object with the desired structure and types.
61
+
62
+ Returns:
63
+ Object with the same structure as obj, filled with values from the JSON file.
64
+
65
+ Raises:
66
+ TypeError: If structure or types don't match between JSON and obj.
67
+ ValueError: If dictionary keys or list/tuple lengths don't match.
68
+ """
69
+ with open(fpath, encoding="utf-8") as f:
70
+ data = json.load(f)
71
+
72
+ def _deserialize(target, source):
73
+ """
74
+ Recursively overwrite the structure in `target` with data from `source`,
75
+ performing strict checks on structure and type.
76
+ Returns the updated version of `target` (especially important for tuples).
77
+ """
78
+
79
+ # If the target is a dictionary, source must be a dictionary as well.
80
+ if isinstance(target, dict):
81
+ if not isinstance(source, dict):
82
+ raise TypeError(f"Type mismatch: expected dict, got {type(source)}")
83
+
84
+ # Check that they have exactly the same set of keys.
85
+ if target.keys() != source.keys():
86
+ raise ValueError(
87
+ f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
88
+ )
89
+
90
+ # Recursively update each key.
91
+ for k in target:
92
+ target[k] = _deserialize(target[k], source[k])
93
+
94
+ return target
95
+
96
+ # If the target is a list, source must be a list as well.
97
+ elif isinstance(target, list):
98
+ if not isinstance(source, list):
99
+ raise TypeError(f"Type mismatch: expected list, got {type(source)}")
100
+
101
+ # Check length
102
+ if len(target) != len(source):
103
+ raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
104
+
105
+ # Recursively update each element.
106
+ for i in range(len(target)):
107
+ target[i] = _deserialize(target[i], source[i])
108
+
109
+ return target
110
+
111
+ # If the target is a tuple, the source must be a list in JSON,
112
+ # which we'll convert back to a tuple.
113
+ elif isinstance(target, tuple):
114
+ if not isinstance(source, list):
115
+ raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
116
+
117
+ if len(target) != len(source):
118
+ raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
119
+
120
+ # Convert each element, forming a new tuple.
121
+ converted_items = []
122
+ for t_item, s_item in zip(target, source, strict=False):
123
+ converted_items.append(_deserialize(t_item, s_item))
124
+
125
+ # Return a brand new tuple (tuples are immutable in Python).
126
+ return tuple(converted_items)
127
+
128
+ # Otherwise, we're dealing with a "primitive" (int, float, str, bool, None).
129
+ else:
130
+ # Check the exact type. If these must match 1:1, do:
131
+ if type(target) is not type(source):
132
+ raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
133
+ return source
134
+
135
+ # Perform the in-place/recursive deserialization
136
+ updated_obj = _deserialize(obj, data)
137
+ return updated_obj
@@ -0,0 +1,214 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for working with the LIBERO robotics environment.
16
+
17
+ This module provides functions for converting LIBERO observations to PyTorch tensors,
18
+ summarizing LIBERO evaluation results, and recording observations from LIBERO environments.
19
+ """
20
+
21
+ import logging
22
+ from pathlib import Path
23
+
24
+ import imageio
25
+ import numpy as np
26
+ import torch
27
+ from einops import rearrange
28
+ from robosuite.utils.transform_utils import quat2axisangle
29
+
30
+
31
def rotate_numpy_image(image: np.ndarray) -> np.ndarray:
    """Rotate an image by 180 degrees and normalize it.

    Args:
        image: Input image in HWC layout with values in [0, 255].

    Returns:
        The image rotated 180 degrees, scaled to [0, 1] (float64), in CHW layout.
    """
    # Scale pixel values from [0, 255] into [0, 1].
    normalized = image.astype(float) / 255.0
    # Two successive 90-degree rotations == a 180-degree rotation.
    rotated = np.rot90(normalized, 2)
    # HWC -> CHW via plain numpy; equivalent to einops rearrange("H W C -> C H W")
    # but avoids the extra third-party dependency for a one-line transpose.
    return np.transpose(rotated, (2, 0, 1))
43
+
44
+
45
def _libero2np(obs: dict[str, np.ndarray], cfg) -> dict[str, str | np.ndarray]:
    """Build the numpy-format model input from a raw LIBERO observation.

    Args:
        obs: Raw LIBERO observation with robot proprioception and camera images.
        cfg: Config supplying the task language, max_state_dim, num_cams and
            action_chunk (project type — structure assumed from usage here).

    Returns:
        Dict with rotated camera images, zero-padded state vector, language
        prompt, and all-False image/action padding masks.
    """
    # Proprioceptive state: EE position, EE orientation as axis-angle, gripper.
    state_vec = np.hstack(
        (
            obs["robot0_eef_pos"],
            quat2axisangle(obs["robot0_eef_quat"]),
            obs["robot0_gripper_qpos"],
        )
    )
    # Right-pad the state with zeros up to the model's fixed state dimension.
    pad_amount = cfg.max_state_dim - len(state_vec)

    return {
        "camera0": rotate_numpy_image(obs["agentview_image"]),
        "camera1": rotate_numpy_image(obs["robot0_eye_in_hand_image"]),
        "prompt": cfg.libero.task.language,
        "state": np.pad(state_vec, (0, pad_amount)),
        "img_is_pad": np.zeros(cfg.num_cams, dtype=bool),
        "action_is_pad": np.zeros(cfg.action_chunk, dtype=bool),
    }
73
+
74
+
75
+ def _np2torch(
76
+ np_input: dict[str, str | np.ndarray], device: str, dtype: torch.dtype
77
+ ) -> dict[str, str | torch.Tensor]:
78
+ """Convert numpy arrays in dictionary to PyTorch tensors.
79
+
80
+ Args:
81
+ np_input: Dictionary containing numpy arrays and strings.
82
+ device: Target device for tensors (e.g., 'cuda', 'cpu').
83
+ dtype: Target dtype for floating point tensors.
84
+
85
+ Returns:
86
+ Dictionary with numpy arrays converted to PyTorch tensors on the
87
+ specified device. String values are preserved as-is.
88
+
89
+ Raises:
90
+ TypeError: If a value type is not supported (not str or np.ndarray).
91
+ """
92
+ torch_input = {}
93
+ for k, v in np_input.items():
94
+ if isinstance(v, str):
95
+ torch_input[k] = v
96
+ elif isinstance(v, np.ndarray):
97
+ # .copy() ensures the array is contiguous for PyTorch to use it
98
+ tensor = torch.tensor(v.copy())
99
+ if tensor.dtype.is_floating_point:
100
+ tensor = tensor.to(dtype=dtype)
101
+ torch_input[k] = tensor.to(device)
102
+ else:
103
+ raise TypeError(f"Unsupported type {type(v)} for key {k}.")
104
+ return torch_input
105
+
106
+
107
def libero2torch(
    obs: dict[str, np.ndarray], cfg, device: str, dtype: torch.dtype
) -> dict[str, str | torch.Tensor]:
    """Convert a raw LIBERO observation into model-ready torch inputs.

    Thin composition of ``_libero2np`` (numpy preprocessing: images, state,
    prompt, padding masks) followed by ``_np2torch`` (tensor conversion and
    device placement).

    Args:
        obs: Raw LIBERO observation with robot state and camera images.
        cfg: Config providing task language, state dimensions, camera count, etc.
        device: Target device string for the tensors.
        dtype: Floating-point dtype for float tensors.

    Returns:
        Dict of strings and tensors on ``device``, ready for the policy.
    """
    return _np2torch(_libero2np(obs, cfg), device, dtype)
125
+
126
+
127
def summarize_libero_results(results: list[int]) -> dict:
    """Aggregate per-episode LIBERO outcome codes into summary statistics.

    Each entry encodes one episode: a non-negative value is a success (the
    number of steps taken), -1 is a failure, and -2 is a crash. Other negative
    codes, if present, fall into no bucket.

    Args:
        results: Per-episode outcome codes as described above.

    Returns:
        Summary dict with counts, indices, rates, the raw results list, and
        the mean steps of successful episodes (None when there are none), or
        a placeholder message dict when ``results`` is empty.
    """
    if not results:
        return {"message": "No results to summarize."}

    total = len(results)

    # Single pass classifying each episode index into its outcome bucket.
    successes: list[int] = []
    failures: list[int] = []
    crashes: list[int] = []
    for idx, outcome in enumerate(results):
        if outcome >= 0:
            successes.append(idx)
        elif outcome == -1:
            failures.append(idx)
        elif outcome == -2:
            crashes.append(idx)

    mean_steps = float(np.mean([r for r in results if r >= 0])) if successes else None

    return {
        "total_simulations": total,
        "success_indices": successes,
        "failure_indices": failures,
        "crashed_indices": crashes,
        "success_count": len(successes),
        "failure_count": len(failures),
        "crashed_count": len(crashes),
        "success_rate": len(successes) / total,
        "failure_rate": len(failures) / total,
        "crashed_rate": len(crashes) / total,
        "steps_taken": results,
        "avg_steps_taken_until_success": mean_steps,
    }
167
+
168
+
169
+ class LiberoObservationRecorder:
170
+ """Context manager for recording LIBERO observations to video files.
171
+
172
+ This class is not multi-processing safe. Each process should use a different
173
+ (folder, camera_name) pair.
174
+
175
+ Args:
176
+ folder: Directory path where video files will be saved. If None, recording
177
+ is disabled.
178
+ camera_names: List of camera names to record. If None, no cameras are recorded.
179
+ fps: Frames per second for the output videos. Defaults to 10.
180
+ extension: Video file extension. Defaults to "mp4".
181
+ """
182
+
183
+ def __init__(self, folder, camera_names=None, fps=10, extension="mp4"):
184
+ if folder is None:
185
+ logging.debug("No folder specified for video recording. Skipping.")
186
+ self.writers = []
187
+ self.camera_names = []
188
+ return
189
+
190
+ self.camera_names = camera_names or []
191
+ folder = Path(folder)
192
+ Path(folder).mkdir(parents=True, exist_ok=True)
193
+ video_files = [folder / f"{cam}.{extension}" for cam in self.camera_names]
194
+ logging.debug("Creating video files: %s", video_files)
195
+ self.writers = [imageio.get_writer(vf, fps=fps) for vf in video_files]
196
+
197
+ def __enter__(self):
198
+ return self
199
+
200
+ def record(self, obs):
201
+ """Record a single observation frame.
202
+
203
+ Args:
204
+ obs: Observation dictionary containing camera images keyed by camera name.
205
+ """
206
+ for writer, camera in zip(self.writers, self.camera_names, strict=True):
207
+ writer.append_data(np.rot90(obs[camera], k=2))
208
+
209
+ def __exit__(self, exc_type, exc_val, exc_tb):
210
+ logging.debug("Closing video writers.")
211
+ for writer in self.writers:
212
+ writer.close()
213
+ logging.debug("Video writers closed.")
214
+ return False