opentau 0.1.1__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opentau-0.1.1/src/opentau.egg-info → opentau-0.2.0}/PKG-INFO +37 -17
- {opentau-0.1.1 → opentau-0.2.0}/README.md +26 -11
- {opentau-0.1.1 → opentau-0.2.0}/pyproject.toml +26 -7
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/default.py +16 -0
- opentau-0.2.0/src/opentau/configs/deployment.py +85 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/train.py +5 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/factory.py +43 -10
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/lerobot_dataset.py +19 -19
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/video_utils.py +11 -6
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pi05/configuration_pi05.py +9 -6
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pi05/modeling_pi05.py +296 -30
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pi05/paligemma_with_expert.py +20 -20
- opentau-0.2.0/src/opentau/scripts/grpc/__init__.py +19 -0
- opentau-0.2.0/src/opentau/scripts/grpc/client.py +601 -0
- opentau-0.2.0/src/opentau/scripts/grpc/robot_inference_pb2.py +61 -0
- opentau-0.2.0/src/opentau/scripts/grpc/robot_inference_pb2_grpc.py +210 -0
- opentau-0.2.0/src/opentau/scripts/grpc/server.py +313 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/launch.py +12 -4
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/train.py +94 -17
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/visualize_dataset.py +141 -38
- opentau-0.2.0/src/opentau/utils/transformers_patch.py +279 -0
- {opentau-0.1.1 → opentau-0.2.0/src/opentau.egg-info}/PKG-INFO +37 -17
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau.egg-info/SOURCES.txt +6 -3
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau.egg-info/entry_points.txt +1 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau.egg-info/requires.txt +10 -4
- opentau-0.1.1/src/opentau/scripts/libero_simulation_parallel.py +0 -356
- opentau-0.1.1/src/opentau/scripts/libero_simulation_sequential.py +0 -122
- opentau-0.1.1/src/opentau/scripts/visualize_dataset_html.py +0 -507
- opentau-0.1.1/src/opentau/utils/transformers_patch.py +0 -48
- {opentau-0.1.1 → opentau-0.2.0}/LICENSE +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/setup.cfg +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/__version__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/libero.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/parser.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/policies.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/reward.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/types.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/constants.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/backward_compatibility.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/compute_stats.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/dataset_mixture.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/grounding/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/grounding/base.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/grounding/clevr.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/grounding/cocoqa.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/grounding/dummy.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/grounding/pixmo.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/grounding/vsr.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/image_writer.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/online_buffer.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/push_dataset_to_hub/utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/sampler.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/standard_data_format_mapping.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/transforms.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/v2/convert_dataset_v1_to_v2.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/v21/_remove_language_instruction.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/v21/convert_dataset_v20_to_v21.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/v21/convert_stats.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/envs/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/envs/configs.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/envs/factory.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/envs/libero.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/envs/utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/optim/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/optim/factory.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/optim/optimizers.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/optim/schedulers.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/planner/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/planner/high_level_planner.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/planner/utils/memory.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/planner/utils/utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/factory.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/normalize.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pi0/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pi0/configuration_pi0.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pi0/modeling_pi0.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pi0/paligemma_with_expert.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pi05/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/pretrained.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/value/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/value/configuration_value.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/value/modeling_value.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/value/reward.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/policies/value/siglip_gemma.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/actions_mse_loss.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/bin_to_safetensors.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/compute_max_token_length.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/display_sys_info.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/download_libero_benchmarks.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/eval.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/export_to_onnx.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/fake_tensor_training.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/get_advantage_and_percentiles.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/high_level_planner_inference.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/inference.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/nav_high_level_planner_inference.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/scripts/zero_to_fp32.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/__init__.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/accelerate_utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/benchmark.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/fake_tensor.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/hub.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/import_utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/io_utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/libero.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/libero_dataset_recorder.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/logging_utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/monkey_patch.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/random_utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/train_utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau/utils/utils.py +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau.egg-info/dependency_links.txt +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/src/opentau.egg-info/top_level.txt +0 -0
- {opentau-0.1.1 → opentau-0.2.0}/tests/test_available.py +0 -0
{opentau-0.1.1/src/opentau.egg-info → opentau-0.2.0}/PKG-INFO

```diff
@@ -1,8 +1,8 @@
 Metadata-Version: 2.4
 Name: opentau
-Version: 0.1.1
+Version: 0.2.0
 Summary: OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch
-Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>, Xingrui Gu <xingrui_gu@berkeley.edu>
+Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>
 License: Apache-2.0
 Project-URL: homepage, https://github.com/TensorAuto/OpenTau
 Project-URL: issues, https://github.com/TensorAuto/OpenTau/issues
@@ -41,9 +41,9 @@ Requires-Dist: pynput>=1.7.7
 Requires-Dist: pyzmq>=26.2.1
 Requires-Dist: rerun-sdk>=0.21.0
 Requires-Dist: termcolor>=2.4.0
-Requires-Dist: torch
+Requires-Dist: torch>=2.7.1
 Requires-Dist: torchcodec<0.5.0,>=0.4.0; sys_platform != "win32" and (sys_platform != "linux" or (platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l")) and (sys_platform != "darwin" or platform_machine != "x86_64")
-Requires-Dist: torchvision
+Requires-Dist: torchvision>=0.22.1
 Requires-Dist: wandb>=0.16.3
 Requires-Dist: zarr>=2.17.0
 Requires-Dist: scikit-learn>=1.7.1
@@ -52,7 +52,7 @@ Requires-Dist: onnxruntime>=1.22.1; sys_platform == "darwin" or platform_machine
 Requires-Dist: onnxruntime-gpu>=1.22.0; (sys_platform == "linux" and platform_machine == "x86_64") or (sys_platform == "win32" and (platform_machine == "AMD64" or platform_machine == "x86_64"))
 Requires-Dist: onnxscript>=0.3.1
 Requires-Dist: onnx-ir>=0.1.4
-Requires-Dist: …
+Requires-Dist: transformers==4.53.3
 Requires-Dist: scipy>=1.15.2
 Requires-Dist: pytest>=8.1.0
 Requires-Dist: pytest-cov>=5.0.0
@@ -62,6 +62,10 @@ Requires-Dist: scikit-image>=0.23.2
 Requires-Dist: pandas>=2.2.2
 Requires-Dist: accelerate>=1.4.0
 Requires-Dist: deepspeed>=0.17.1
+Requires-Dist: gymnasium[other]>=0.29
+Requires-Dist: grpcio>=1.60.0
+Requires-Dist: grpcio-tools>=1.60.0
+Requires-Dist: protobuf>=4.25.0
 Provides-Extra: dev
 Requires-Dist: pre-commit>=3.7.0; extra == "dev"
 Requires-Dist: debugpy>=1.8.1; extra == "dev"
@@ -93,10 +97,11 @@ Requires-Dist: libero; extra == "libero"
 Requires-Dist: numpy<2; extra == "libero"
 Requires-Dist: gym<0.27,>=0.25; extra == "libero"
 Requires-Dist: pyopengl-accelerate==3.1.7; sys_platform == "linux" and extra == "libero"
-Requires-Dist: gymnasium[other]>=0.29; extra == "libero"
 Requires-Dist: mujoco>=3.1.6; sys_platform == "linux" and extra == "libero"
 Requires-Dist: pyopengl==3.1.7; sys_platform == "linux" and extra == "libero"
 Requires-Dist: numpy==1.26.4; sys_platform == "linux" and extra == "libero"
+Provides-Extra: urdf
+Requires-Dist: rerun-sdk>=0.28.2; extra == "urdf"
 Dynamic: license-file
 
 <p align="center">
@@ -105,6 +110,19 @@ Dynamic: license-file
   </a>
 </p>
 
+<p align="center">
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/cpu_test.yml?query=branch%3Amain"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/cpu_test.yml/badge.svg?branch=main" alt="CPU Tests"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/gpu_test.yml"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/gpu_test.yml/badge.svg" alt="Nightly GPU Tests"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/regression_test.yml"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/regression_test.yml/badge.svg" alt="Nightly Regression Tests"></a>
+  <a href="https://opentau.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/opentau/badge/?version=latest" alt="Documentation"></a>
+  <a href="https://pypi.org/project/opentau/"><img src="https://img.shields.io/pypi/v/opentau" alt="Version"></a>
+  <a href="https://pypi.org/project/opentau/"><img src="https://img.shields.io/pypi/status/opentau" alt="Status"></a>
+  <a href="https://www.python.org/downloads/"><img src="https://img.shields.io/pypi/pyversions/opentau" alt="Python versions"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License"></a>
+  <a href="https://hub.docker.com/r/tensorauto/opentau"><img src="https://img.shields.io/docker/v/tensorauto/opentau?label=Docker" alt="Docker"></a>
+  <a href="https://github.com/pre-commit/pre-commit"><img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit" alt="pre-commit"></a>
+</p>
+
 # OpenTau - Train VLA models with state-of-the-art techniques by Tensor
 
 At Tensor, we are pushing the frontier of large foundation models for physical AI. In robot learning, a vision-language-action (VLA) model is a multimodal foundation model that integrates vision, language, and action. Today, VLA represents the leading approach for embodied AI, spanning autonomous driving, robot manipulation, and navigation.
@@ -122,17 +140,19 @@ Whether you use the official OpenPi codebase or LeRobot’s reimplementation, yo
 
 OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we also use it internally to train our proprietary in-house models. Our goal is to help you train VLAs on any dataset while fully leveraging state-of-the-art techniques. We plan to continuously upgrade this repository to keep pace with the state of the art in the robotics community.
 
-…
-| Knowledge Insulation (KI) between VLM and Action Decoder | …
-…
+| Features | OpenPi | LeRobot | **OpenTau** |
+|---------------------------------------------------------:|:-----------------------:|:--------------------------------:|:-----------:|
+| Co-training with Heterogeneous Datasets | ❌ | ❌ | ✅ |
+| Discrete Actions Training in $\pi_{0.5}$ | ❌ | ❌ | ✅ |
+| Knowledge Insulation (KI) between VLM and Action Decoder | ❌ | ❌ | ✅ |
+| Dropout Layers in PaliGemma | ✅ (Jax) <br>❌ (PyTorch) | ❌ | ✅ |
+| Multi-Node and Multi-GPU Training | ❌ | ✅ | ✅ |
+| Fully Functioning $\pi_{0.5}$ Checkpoint | ✅ | ❌ <br> (Missing Text Embeddings) | ✅ |
+| Visualize dataset with URDF models | ❌ | ❌ | ✅ |
+| Simulation Environments for Evaluating Models | ❌ | ✅ | ✅ |
+| Create Validation Splits During Training | ❌ | ❌ | ✅ |
+| $\pi^{*}_{0.6}$ style Reinforcement Learning Pipeline | ❌ | ❌ | ✅ |
+| Framework | Jax / PyTorch | PyTorch | PyTorch |
 
 ## Quick Start
 If you are familiar with LeRobot, getting started with OpenTau is very easy.
```
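The new torchcodec pin above is guarded by a compound PEP 508 environment marker that excludes Windows, ARM Linux, and Intel macOS. As an illustration (not part of the package), the marker can be evaluated for the current interpreter with `packaging`, the same machinery pip uses during dependency resolution:

```python
# Illustration only: evaluate the torchcodec marker from the metadata
# above for the current platform, as pip would when resolving dependencies.
from packaging.markers import Marker

torchcodec_marker = Marker(
    'sys_platform != "win32" and (sys_platform != "linux" or '
    '(platform_machine != "aarch64" and platform_machine != "arm64" '
    'and platform_machine != "armv7l")) and '
    '(sys_platform != "darwin" or platform_machine != "x86_64")'
)

# True on e.g. Linux/x86_64; False on Windows, ARM Linux, and Intel macOS.
print(torchcodec_marker.evaluate())
```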
{opentau-0.1.1 → opentau-0.2.0}/README.md

```diff
@@ -4,6 +4,19 @@
   </a>
 </p>
 
+<p align="center">
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/cpu_test.yml?query=branch%3Amain"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/cpu_test.yml/badge.svg?branch=main" alt="CPU Tests"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/gpu_test.yml"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/gpu_test.yml/badge.svg" alt="Nightly GPU Tests"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/regression_test.yml"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/regression_test.yml/badge.svg" alt="Nightly Regression Tests"></a>
+  <a href="https://opentau.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/opentau/badge/?version=latest" alt="Documentation"></a>
+  <a href="https://pypi.org/project/opentau/"><img src="https://img.shields.io/pypi/v/opentau" alt="Version"></a>
+  <a href="https://pypi.org/project/opentau/"><img src="https://img.shields.io/pypi/status/opentau" alt="Status"></a>
+  <a href="https://www.python.org/downloads/"><img src="https://img.shields.io/pypi/pyversions/opentau" alt="Python versions"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License"></a>
+  <a href="https://hub.docker.com/r/tensorauto/opentau"><img src="https://img.shields.io/docker/v/tensorauto/opentau?label=Docker" alt="Docker"></a>
+  <a href="https://github.com/pre-commit/pre-commit"><img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit" alt="pre-commit"></a>
+</p>
+
 # OpenTau - Train VLA models with state-of-the-art techniques by Tensor
 
 At Tensor, we are pushing the frontier of large foundation models for physical AI. In robot learning, a vision-language-action (VLA) model is a multimodal foundation model that integrates vision, language, and action. Today, VLA represents the leading approach for embodied AI, spanning autonomous driving, robot manipulation, and navigation.
@@ -21,17 +34,19 @@ Whether you use the official OpenPi codebase or LeRobot’s reimplementation, yo
 
 OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we also use it internally to train our proprietary in-house models. Our goal is to help you train VLAs on any dataset while fully leveraging state-of-the-art techniques. We plan to continuously upgrade this repository to keep pace with the state of the art in the robotics community.
 
-…
-| Knowledge Insulation (KI) between VLM and Action Decoder | …
-…
+| Features | OpenPi | LeRobot | **OpenTau** |
+|---------------------------------------------------------:|:-----------------------:|:--------------------------------:|:-----------:|
+| Co-training with Heterogeneous Datasets | ❌ | ❌ | ✅ |
+| Discrete Actions Training in $\pi_{0.5}$ | ❌ | ❌ | ✅ |
+| Knowledge Insulation (KI) between VLM and Action Decoder | ❌ | ❌ | ✅ |
+| Dropout Layers in PaliGemma | ✅ (Jax) <br>❌ (PyTorch) | ❌ | ✅ |
+| Multi-Node and Multi-GPU Training | ❌ | ✅ | ✅ |
+| Fully Functioning $\pi_{0.5}$ Checkpoint | ✅ | ❌ <br> (Missing Text Embeddings) | ✅ |
+| Visualize dataset with URDF models | ❌ | ❌ | ✅ |
+| Simulation Environments for Evaluating Models | ❌ | ✅ | ✅ |
+| Create Validation Splits During Training | ❌ | ❌ | ✅ |
+| $\pi^{*}_{0.6}$ style Reinforcement Learning Pipeline | ❌ | ❌ | ✅ |
+| Framework | Jax / PyTorch | PyTorch | PyTorch |
 
 ## Quick Start
 If you are familiar with LeRobot, getting started with OpenTau is very easy.
```
{opentau-0.1.1 → opentau-0.2.0}/pyproject.toml

```diff
@@ -20,13 +20,12 @@ huggingface = "https://huggingface.co/TensorAuto"
 
 [project]
 name = "opentau"
-version = "0.1.1"
+version = "0.2.0"
 description = "OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch"
 authors = [
     { name = "Shuheng Liu", email = "wish1104@icloud.com" },
     { name = "William Yue", email = "williamyue37@gmail.com" },
     { name = "Akshay Shah", email = "akshayhitendrashah@gmail.com" },
-    { name = "Xingrui Gu", email = "xingrui_gu@berkeley.edu" }
 ]
 readme = "README.md"
 license = { text = "Apache-2.0" }
@@ -65,9 +64,9 @@ dependencies = [
     "pyzmq>=26.2.1",
     "rerun-sdk>=0.21.0",
     "termcolor>=2.4.0",
-    "torch",
+    "torch>=2.7.1",
     "torchcodec>=0.4.0, <0.5.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
-    "torchvision",
+    "torchvision>=0.22.1",
     "wandb>=0.16.3",
     "zarr>=2.17.0",
     "scikit-learn>=1.7.1",
@@ -76,7 +75,7 @@ dependencies = [
     "onnxruntime-gpu>=1.22.0 ; ((sys_platform == 'linux' and platform_machine == 'x86_64') or (sys_platform == 'win32' and (platform_machine == 'AMD64' or platform_machine == 'x86_64'))) ",
     "onnxscript>=0.3.1",
     "onnx-ir>=0.1.4",
-    "…",
+    "transformers==4.53.3",
     "scipy>=1.15.2",
     "pytest>=8.1.0",
     "pytest-cov>=5.0.0",
@@ -85,13 +84,18 @@ dependencies = [
     "scikit-image>=0.23.2",
     "pandas>=2.2.2",
     "accelerate>=1.4.0",
-    "deepspeed>=0.17.1"
+    "deepspeed>=0.17.1",
+    "gymnasium[other]>=0.29",
+    "grpcio>=1.60.0",
+    "grpcio-tools>=1.60.0",
+    "protobuf>=4.25.0",
 ]
 
 [project.scripts]
 opentau-train = "opentau.scripts.launch:train"
 opentau-eval = "opentau.scripts.launch:eval"
 opentau-export = "opentau.scripts.launch:export"
+opentau-dataset-viz = "opentau.scripts.launch:visualize"
 
 [project.optional-dependencies]
 dev = ["pre-commit>=3.7.0",
@@ -123,15 +127,26 @@ libero = [
     "numpy<2",
     "gym>=0.25,<0.27",
     "pyopengl-accelerate==3.1.7 ; sys_platform == 'linux'",
-    "gymnasium[other]>=0.29",
     "mujoco>=3.1.6 ; sys_platform == 'linux'",
     "pyopengl==3.1.7 ; sys_platform == 'linux'",
     "numpy==1.26.4 ; sys_platform == 'linux'",
 ]
+urdf = [
+    "rerun-sdk>=0.28.2",
+]
 
 [tool.uv.sources]
 libero = { git = "https://github.com/shuheng-liu/LIBERO" , branch = "master" } # the official libero repo is misconfigured for pip install with git
 
+# libero depends on gym, which depends on numpy 1.x, while rerun only supports urdf in v0.28 which requires numpy 2.x
+[tool.uv]
+conflicts = [
+    [
+        { extra = "libero" },
+        { extra = "urdf" },
+    ],
+]
+
 [tool.setuptools.packages.find]
 where = ["src"]
 
@@ -142,6 +157,10 @@ target-version = "py310"
 [tool.ruff.lint]
 select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
 
+[tool.ruff.lint.per-file-ignores]
+# Server must implement gRPC interface with PascalCase method names
+"src/opentau/scripts/grpc/server.py" = ["N802"]
+
 [tool.bandit]
 exclude_dirs = [
     "tests",
```
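The new `per-file-ignores` entry exists because servicer classes generated by `grpcio-tools` expose RPC methods under the PascalCase names declared in the `.proto` service definition, which ruff's N802 rule ("function name should be lowercase") would otherwise flag. A minimal sketch of the shape involved; the servicer class and `Predict` method names are hypothetical stand-ins, not taken from `robot_inference_pb2_grpc.py`:

```python
# Hypothetical sketch: a gRPC servicer must override the PascalCase RPC
# names that grpcio-tools generates from the .proto service definition.
import grpc


class RobotInferenceServicer:  # stand-in for the generated base class
    def Predict(self, request, context):  # N802: the name is dictated by the .proto file
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        raise NotImplementedError("Method not implemented!")
```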
{opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/default.py

```diff
@@ -96,6 +96,11 @@ class DatasetConfig:
     data_features_name_mapping: dict[str, str] | None = None
     loss_type_mapping: str | None = None
 
+    # Ratio of the dataset to be used for validation. Please specify a value.
+    # If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
+    # Defaults to 0.05.
+    val_split_ratio: float = 0.05
+
     def __post_init__(self):
         """Validate dataset configuration and register custom mappings if provided."""
         if (self.repo_id is None) == (self.grounding is None):
@@ -148,6 +153,11 @@ class DatasetMixtureConfig:
     image_resample_strategy: str = "nearest"
     # Resample strategy for non-image features, such as action or state
     vector_resample_strategy: str = "nearest"
+    # Ratio of the dataset to be used for validation. Please specify a value.
+    # If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
+    # This value is applied to all datasets in the mixture.
+    # Defaults to 0.05.
+    val_split_ratio: float = 0.05
 
     def __post_init__(self):
         """Validate dataset mixture configuration."""
@@ -163,6 +173,12 @@ class DatasetMixtureConfig:
             raise ValueError(
                 f"`vector_resample_strategy` must be one of ['linear', 'nearest'], got {self.vector_resample_strategy}."
             )
+        if self.val_split_ratio < 0 or self.val_split_ratio > 1:
+            raise ValueError(f"`val_split_ratio` must be between 0 and 1, got {self.val_split_ratio}.")
+
+        # set the val_split_ratio for all datasets in the mixture
+        for dataset_cfg in self.datasets:
+            dataset_cfg.val_split_ratio = self.val_split_ratio
 
 
 @dataclass
```
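`DatasetMixtureConfig.__post_init__` now validates the ratio and then overwrites `val_split_ratio` on every child `DatasetConfig`, so a per-dataset value is always superseded by the mixture-level one. A standalone sketch of that propagation pattern, with both configs reduced to the fields relevant here:

```python
# Standalone sketch of the propagation pattern in the diff above,
# with the configs cut down to just the fields that matter here.
from dataclasses import dataclass, field


@dataclass
class DatasetConfig:
    repo_id: str
    val_split_ratio: float = 0.05


@dataclass
class DatasetMixtureConfig:
    datasets: list[DatasetConfig] = field(default_factory=list)
    val_split_ratio: float = 0.05

    def __post_init__(self):
        if self.val_split_ratio < 0 or self.val_split_ratio > 1:
            raise ValueError(f"`val_split_ratio` must be between 0 and 1, got {self.val_split_ratio}.")
        # the mixture-level ratio wins over whatever each dataset declared
        for dataset_cfg in self.datasets:
            dataset_cfg.val_split_ratio = self.val_split_ratio


mixture = DatasetMixtureConfig(
    datasets=[DatasetConfig("a/b"), DatasetConfig("c/d", val_split_ratio=0.2)],
    val_split_ratio=0.1,
)
assert all(d.val_split_ratio == 0.1 for d in mixture.datasets)
```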
opentau-0.2.0/src/opentau/configs/deployment.py (new file)

```diff
@@ -0,0 +1,85 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Deployment configuration classes for inference servers.
+
+This module provides configuration classes for deploying trained models
+as inference servers, including gRPC server settings.
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class ServerConfig:
+    """Configuration for the gRPC inference server.
+
+    This class contains all configuration parameters needed to run a gRPC
+    inference server for robot policy models.
+
+    Args:
+        port: Port number to serve on. Must be between 1 and 65535.
+            Defaults to 50051.
+        max_workers: Maximum number of gRPC worker threads for handling
+            concurrent requests. Defaults to 4.
+        max_send_message_length_mb: Maximum size of outgoing messages in
+            megabytes. Defaults to 100.
+        max_receive_message_length_mb: Maximum size of incoming messages in
+            megabytes. Defaults to 100.
+
+    Raises:
+        ValueError: If port is not in valid range or max_workers is less than 1.
+
+    Example:
+        >>> config = ServerConfig(port=50051, max_workers=8)
+        >>> config.port
+        50051
+    """
+
+    port: int = 50051
+    max_workers: int = 4
+    max_send_message_length_mb: int = 100
+    max_receive_message_length_mb: int = 100
+
+    def __post_init__(self):
+        """Validate server configuration parameters."""
+        if not 1 <= self.port <= 65535:
+            raise ValueError(f"`port` must be between 1 and 65535, got {self.port}.")
+        if self.max_workers < 1:
+            raise ValueError(f"`max_workers` must be at least 1, got {self.max_workers}.")
+        if self.max_send_message_length_mb < 1:
+            raise ValueError(
+                f"`max_send_message_length_mb` must be at least 1, got {self.max_send_message_length_mb}."
+            )
+        if self.max_receive_message_length_mb < 1:
+            raise ValueError(
+                f"`max_receive_message_length_mb` must be at least 1, got {self.max_receive_message_length_mb}."
+            )
+
+    @property
+    def max_send_message_length(self) -> int:
+        """Get maximum send message length in bytes.
+
+        Returns:
+            Maximum send message length in bytes.
+        """
+        return self.max_send_message_length_mb * 1024 * 1024
+
+    @property
+    def max_receive_message_length(self) -> int:
+        """Get maximum receive message length in bytes.
+
+        Returns:
+            Maximum receive message length in bytes.
+        """
+        return self.max_receive_message_length_mb * 1024 * 1024
```
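The two properties convert the megabyte-denominated fields into the byte counts that gRPC channel options expect. A sketch of how they can plug into `grpc.server`; the real wiring lives in `src/opentau/scripts/grpc/server.py` and may differ in detail:

```python
# Sketch: feeding ServerConfig's byte-size properties into grpc.server
# options. Assumes opentau 0.2.0 and grpcio are installed.
from concurrent import futures

import grpc

from opentau.configs.deployment import ServerConfig

cfg = ServerConfig(port=50051, max_workers=8)
server = grpc.server(
    futures.ThreadPoolExecutor(max_workers=cfg.max_workers),
    options=[
        # gRPC takes these limits in bytes, hence the *_mb -> bytes properties
        ("grpc.max_send_message_length", cfg.max_send_message_length),
        ("grpc.max_receive_message_length", cfg.max_receive_message_length),
    ],
)
server.add_insecure_port(f"[::]:{cfg.port}")
```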
{opentau-0.1.1 → opentau-0.2.0}/src/opentau/configs/train.py

```diff
@@ -32,6 +32,7 @@ from huggingface_hub.errors import HfHubHTTPError
 
 from opentau.configs import parser
 from opentau.configs.default import DatasetMixtureConfig, EvalConfig, WandBConfig
+from opentau.configs.deployment import ServerConfig
 from opentau.configs.policies import PreTrainedConfig
 from opentau.envs.configs import EnvConfig
 from opentau.optim import OptimizerConfig
@@ -116,6 +117,7 @@ class TrainPipelineConfig(HubMixin):
             is disabled. Defaults to 0.
         last_checkpoint_only: If True, only evaluate the last checkpoint.
             Defaults to True.
+        server: Configuration for the gRPC inference server. Defaults to ServerConfig().
     """
 
     dataset_mixture: DatasetMixtureConfig
@@ -163,7 +165,10 @@ class TrainPipelineConfig(HubMixin):
     env: EnvConfig | None = None
     eval: EvalConfig | None = field(default_factory=EvalConfig)
     eval_freq: int = 0  # evaluate every eval_freq steps
+    val_freq: int = 0  # validate every val_freq steps, if 0, then a validation split is not created
     last_checkpoint_only: bool = True
+    # gRPC inference server configuration
+    server: ServerConfig = field(default_factory=ServerConfig)
 
     def __post_init__(self):
         """Initialize post-creation attributes and validate batch size configuration."""
```
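Note that `server` is declared with `field(default_factory=ServerConfig)` rather than a bare `ServerConfig()` default. A small sketch of why (assuming opentau 0.2.0 is installed for the import):

```python
# Why field(default_factory=...): a factory builds a fresh ServerConfig per
# instance, so two configs never share (and accidentally co-mutate) one
# object; recent Python versions also reject a non-hashable dataclass
# instance used directly as a class-level default.
from dataclasses import dataclass, field

from opentau.configs.deployment import ServerConfig


@dataclass
class Demo:
    server: ServerConfig = field(default_factory=ServerConfig)


a, b = Demo(), Demo()
assert a.server is not b.server  # independent instances
```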
{opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/factory.py

```diff
@@ -61,7 +61,11 @@ Example:
     >>> dataloader = mixture.get_dataloader()
 """
 
+import copy
+from typing import Tuple, Union
+
 import numpy as np
+import torch
 
 # NOTE: Don't delete; imported for side effects.
 import opentau.datasets.grounding.clevr  # noqa: F401
@@ -151,9 +155,13 @@ def make_dataset(
     cfg: DatasetConfig,
     train_cfg: TrainPipelineConfig,
     return_advantage_input: bool = False,
-) -> BaseDataset:
+) -> Union[BaseDataset, Tuple[BaseDataset, BaseDataset]]:
    """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
 
+    A train and validation dataset are returned if `train_cfg.val_freq` is greater than 0.
+    The validation dataset is a subset of the train dataset, and is used for evaluation during training.
+    The validation dataset is created by splitting the train dataset into train and validation sets based on `cfg.val_split_ratio`.
+
     Args:
         cfg (DatasetConfig): A DatasetConfig used to create a LeRobotDataset.
         train_cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
@@ -161,10 +169,11 @@ def make_dataset(
             "episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.
 
     Raises:
-        …
+        ValueError: If exactly one of `cfg.grounding` and `cfg.repo_id` is not provided.
+        ValueError: If `cfg.grounding` is not a supported grounding dataset.
 
     Returns:
-        BaseDataset
+        BaseDataset or Tuple[BaseDataset, BaseDataset]: A single dataset or a tuple of (train_dataset, val_dataset) if val_freq > 0.
     """
     image_transforms = ImageTransforms(cfg.image_transforms) if cfg.image_transforms.enable else None
 
@@ -209,12 +218,20 @@ def make_dataset(
             dataset.meta.stats[key] = {}
             dataset.meta.stats[key][stats_type] = np.array(stats, dtype=np.float32)
 
+    if train_cfg.val_freq > 0:
+        val_size = int(len(dataset) * cfg.val_split_ratio)
+        train_size = len(dataset) - val_size
+        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+        train_dataset.meta = copy.deepcopy(dataset.meta)
+        val_dataset.meta = copy.deepcopy(dataset.meta)
+        return train_dataset, val_dataset
+
     return dataset
 
 
 def make_dataset_mixture(
     cfg: TrainPipelineConfig, return_advantage_input: bool = False
-) -> WeightedDatasetMixture:
+) -> Union[WeightedDatasetMixture, Tuple[WeightedDatasetMixture, WeightedDatasetMixture]]:
     """Creates a dataset mixture from the provided TrainPipelineConfig.
 
     Args:
@@ -223,10 +240,26 @@ def make_dataset_mixture(
             "episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.
 
     Returns:
-        WeightedDatasetMixture: An instance of WeightedDatasetMixture containing the datasets.
+        WeightedDatasetMixture or Tuple[WeightedDatasetMixture, WeightedDatasetMixture]: An instance of WeightedDatasetMixture containing the datasets, or a tuple of (train_mixture, val_mixture) if val_freq > 0.
     """
-    datasets = [
-        …
-        …
-        …
-        …
+    datasets = []
+    val_datasets = []
+    for dataset_cfg in cfg.dataset_mixture.datasets:
+        res = make_dataset(dataset_cfg, cfg, return_advantage_input=return_advantage_input)
+        if isinstance(res, tuple):
+            datasets.append(res[0])
+            val_datasets.append(res[1])
+        else:
+            datasets.append(res)
+
+    train_mixture = WeightedDatasetMixture(
+        cfg, datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq
+    )
+
+    if val_datasets:
+        val_mixture = WeightedDatasetMixture(
+            cfg, val_datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq
+        )
+        return train_mixture, val_mixture
+
+    return train_mixture
```
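The split sizes are plain arithmetic on `val_split_ratio`, after which `random_split` partitions indices at random. A self-contained illustration with a toy dataset; note the diff passes no `generator`, so the partition is governed by torch's global RNG state and varies run to run unless a seed is fixed:

```python
# Self-contained illustration of the split arithmetic used above:
# 10_000 samples at val_split_ratio=0.05 -> 500 validation, 9_500 train.
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10_000))
val_split_ratio = 0.05

val_size = int(len(dataset) * val_split_ratio)
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

assert (len(train_ds), len(val_ds)) == (9_500, 500)
```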
{opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/lerobot_dataset.py

```diff
@@ -150,6 +150,7 @@ from opentau.policies.value.configuration_value import ValueConfig
 from opentau.policies.value.reward import (
     calculate_return_bins_with_equal_width,
 )
+from opentau.utils.accelerate_utils import get_proc_accelerator
 from opentau.utils.utils import on_accelerate_main_proc
 
 
@@ -324,8 +325,17 @@ class LeRobotDatasetMetadata(DatasetMetadata):
         if is_valid_version(self.revision):
             self.revision = get_safe_version(self.repo_id, self.revision)
 
-        (self.root / "meta").mkdir(exist_ok=True, parents=True)
-        self.pull_from_repo(allow_patterns="meta/")
+        # In distributed training, only rank 0 downloads to avoid race conditions
+        # where other ranks read metadata before the download has finished.
+        acc = get_proc_accelerator()
+        if acc is not None and acc.num_processes > 1:
+            if acc.is_main_process:
+                (self.root / "meta").mkdir(exist_ok=True, parents=True)
+                self.pull_from_repo(allow_patterns="meta/")
+            acc.wait_for_everyone()
+        else:
+            (self.root / "meta").mkdir(exist_ok=True, parents=True)
+            self.pull_from_repo(allow_patterns="meta/")
         self.load_metadata()
 
     def load_metadata(self) -> None:
@@ -633,7 +643,9 @@ class BaseDataset(torch.utils.data.Dataset):
         For example, {"image_key": torch.zeros(2, 3, 224, 224), "image_key_is_pad": [False, True] } will become
         {
             "image_key": torch.zeros(3, 224, 224),
+            "image_key_local": torch.zeros(3, 224, 224),
             "image_key_is_pad: False,
+            "image_key_local_is_pad": True,
         }.
         """
         raise NotImplementedError
@@ -723,14 +735,6 @@ class BaseDataset(torch.utils.data.Dataset):
             if isinstance(value, torch.Tensor) and value.dtype.is_floating_point:
                 standard_item[key] = value.to(dtype=torch.bfloat16)
 
-        # ensure that non-empty strings contain exactly one newline character at the end of the string
-        for key in ["prompt", "response"]:
-            if standard_item[key].endswith(
-                "\n"
-            ):  # ensure there isn't going to be an extra space at the end after calling replace
-                standard_item[key] = standard_item[key][:-1]
-            standard_item[key] = standard_item[key].replace("\n", " ") + "\n"
-
         return standard_item
 
     def resize_with_pad(self, img, width, height, pad_value=0) -> torch.Tensor:
@@ -1787,16 +1791,12 @@ class LeRobotDataset(BaseDataset):
         cam_keys = {v for k, v in name_map.items() if k.startswith("camera")}
         for k in cam_keys:
             images = item.pop(k)
-            …
-            …
-            )
-            item[k + "_local"], item[k] = images
+            if len(images) == 2:
+                item[k + "_local"], item[k] = images
 
-            pads = item.…
-            …
-            …
-            )
-            item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
+            pads = item.get(k + "_is_pad")
+            if hasattr(pads, "__len__") and len(pads) == 2:
+                item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
 
     @staticmethod
     def compute_delta_params(
```
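The metadata download guard is the standard "rank 0 works, everyone else waits at a barrier" idiom. `get_proc_accelerator` is opentau's own helper; the same pattern expressed directly against Hugging Face Accelerate looks roughly like this (a sketch, with a hypothetical cache path standing in for the repo download):

```python
# Generic rank-0-then-barrier pattern with Hugging Face Accelerate.
from pathlib import Path

from accelerate import Accelerator

acc = Accelerator()
meta_dir = Path("data/meta")  # hypothetical shared cache location

if acc.is_main_process:
    # only rank 0 creates/downloads shared files
    meta_dir.mkdir(exist_ok=True, parents=True)
acc.wait_for_everyone()  # other ranks block here until rank 0 is done
assert meta_dir.exists()  # afterwards, safe to read on every rank
```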
{opentau-0.1.1 → opentau-0.2.0}/src/opentau/datasets/video_utils.py

```diff
@@ -108,6 +108,7 @@ import pyarrow as pa
 import torch
 import torchvision
 from datasets.features.features import register_feature
+from packaging import version
 from PIL import Image
 
 
@@ -117,13 +118,17 @@ def get_safe_default_codec() -> str:
     Returns:
         Backend name: "torchcodec" if available, otherwise "pyav".
     """
-    if importlib.util.find_spec("torchcodec"):
-        return "torchcodec"
-    else:
-        logging.warning(
-            "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
-        )
+
+    if version.parse(torch.__version__) >= version.parse("2.8.0"):
         return "pyav"
+    else:
+        if importlib.util.find_spec("torchcodec"):
+            return "torchcodec"
+        else:
+            logging.warning(
+                "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
+            )
+            return "pyav"
 
 
 def decode_video_frames(
```
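The new gate compares `torch.__version__` with `packaging`, which handles PyTorch's real-world version strings: local build tags such as `+cu128` still compare at-or-above the base release, while dev pre-releases fall below it, so CUDA wheels of torch 2.8 are correctly routed to `pyav` (the 2.8 cutoff itself presumably tracks the `torchcodec<0.5` pin's torch compatibility; the diff does not say). A quick check of the comparison semantics:

```python
# Ordering of typical torch version strings under PEP 440 / packaging.
from packaging import version

cutoff = version.parse("2.8.0")
assert version.parse("2.8.0+cu128") >= cutoff       # local tag sorts above the base release
assert version.parse("2.8.1") >= cutoff
assert version.parse("2.7.1+cpu") < cutoff
assert version.parse("2.8.0.dev20250601") < cutoff  # dev pre-release sorts below 2.8.0
```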