opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,829 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which are short but accurate descriptions in plain English
of each task performed in the dataset. This will allow you to easily train models with task-conditioning.

We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
  one episode to the next.
3. Multi task episodes: episodes of your dataset may each contain several different tasks.


You can also provide a robot config .yaml file (not mandatory) to this script via the option
'--robot-config' so that it writes information about the robot (robot type, motor names) this dataset was
recorded with. For now, only Aloha/Koch type robots are supported with this option.


# 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the
'--single-task' option.

Examples:

```bash
python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \
    --repo-id lerobot/aloha_sim_insertion_human_image \
    --single-task "Insert the peg into the socket." \
    --robot-config lerobot/configs/robot/aloha.yaml \
    --local-dir data
```

```bash
python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \
    --repo-id aliberts/koch_tutorial \
    --single-task "Pick the Lego block and drop it in the box on the right." \
    --robot-config lerobot/configs/robot/koch.yaml \
    --local-dir data
```


# 2. Single task episodes
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:

- If your dataset already contains a language instruction column in its parquet file, you can simply provide
  this column's name with the '--tasks-col' arg.

    Example:

    ```bash
    python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \
        --repo-id lerobot/stanford_kuka_multimodal_dataset \
        --tasks-col "language_instruction" \
        --local-dir data
    ```

- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
  '--tasks-path' arg. This file should have the following structure where keys correspond to each
  episode_index in the dataset, and values are the language instruction for that episode.

    Example:

    ```json
    {
        "0": "Do something",
        "1": "Do something else",
        "2": "Do something",
        "3": "Go there",
        ...
    }
    ```

# 3. Multi task episodes
If you have multiple tasks per episode, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg.

Example:

```bash
python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \
    --repo-id lerobot/stanford_kuka_multimodal_dataset \
    --tasks-col "language_instruction" \
    --local-dir data
```
"""

import argparse
import contextlib
import filecmp
import json
import logging
import math
import shutil
import subprocess
import tempfile
from pathlib import Path

import datasets
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from safetensors.torch import load_file

from opentau.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_PARQUET_PATH,
    DEFAULT_VIDEO_PATH,
    EPISODES_PATH,
    INFO_PATH,
    STATS_PATH,
    TASKS_PATH,
    create_branch,
    create_lerobot_dataset_card,
    flatten_dict,
    get_safe_version,
    load_json,
    unflatten_dict,
    write_json,
    write_jsonlines,
)
from opentau.datasets.video_utils import (
    VideoFrame,  # noqa: F401
    get_image_pixel_channels,
    get_video_info,
)
from opentau.robot_devices.robots.configs import RobotConfig
from opentau.robot_devices.robots.utils import make_robot_config

V16 = "v1.6"
V20 = "v2.0"

GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"

def parse_robot_config(robot_cfg: RobotConfig) -> dict:
    """Parse robot configuration to extract robot type and motor names.

    Extracts state and action motor names from the robot configuration.
    Currently supports "aloha" and "koch" robot types.

    Args:
        robot_cfg: Robot configuration object.

    Returns:
        Dictionary with 'robot_type' and 'names' keys. The 'names' dictionary
        maps feature keys to lists of motor names.

    Raises:
        NotImplementedError: If robot type is not supported.
    """
    if robot_cfg.type in ["aloha", "koch"]:
        state_names = [
            f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
            for arm in robot_cfg.follower_arms
            for motor in robot_cfg.follower_arms[arm].motors
        ]
        action_names = [
            # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
            f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
            for arm in robot_cfg.leader_arms
            for motor in robot_cfg.leader_arms[arm].motors
        ]
    # elif robot_cfg["robot_type"] == "stretch3": TODO
    else:
        raise NotImplementedError(
            "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
        )

    return {
        "robot_type": robot_cfg.type,
        "names": {
            "observation.state": state_names,
            "observation.effort": state_names,
            "action": action_names,
        },
    }

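Editorial illustration (not part of the packaged file): a minimal sketch of the motor-naming logic in `parse_robot_config`, using `types.SimpleNamespace` stand-ins for a hypothetical two-arm "aloha"-style config.

```python
# Editorial sketch: emulate a two-arm config with SimpleNamespace stand-ins to show
# how motor names get prefixed with the arm name when there is more than one arm.
from types import SimpleNamespace

arm = SimpleNamespace(motors=["waist", "shoulder", "elbow"])
robot_cfg = SimpleNamespace(
    type="aloha",
    follower_arms={"left": arm, "right": arm},
    leader_arms={"left": arm, "right": arm},
)

state_names = [
    f"{name}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
    for name in robot_cfg.follower_arms
    for motor in robot_cfg.follower_arms[name].motors
]
print(state_names)
# ['left_waist', 'left_shoulder', 'left_elbow', 'right_waist', 'right_shoulder', 'right_elbow']
```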
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
    """Convert statistics from v1.6 safetensors format to v2.0 JSON format.

    Loads statistics from safetensors file, converts to JSON, and validates
    that the conversion preserves all values correctly.

    Args:
        v1_dir: Directory containing v1.6 dataset with meta_data/stats.safetensors.
        v2_dir: Directory where v2.0 dataset meta/stats.json will be written.

    Raises:
        AssertionError: If converted statistics don't match original values.
    """
    safetensor_path = v1_dir / V1_STATS_PATH
    stats = load_file(safetensor_path)
    serialized_stats = {key: value.tolist() for key, value in stats.items()}
    serialized_stats = unflatten_dict(serialized_stats)

    json_path = v2_dir / STATS_PATH
    json_path.parent.mkdir(exist_ok=True, parents=True)
    with open(json_path, "w") as f:
        json.dump(serialized_stats, f, indent=4)

    # Sanity check
    with open(json_path) as f:
        stats_json = json.load(f)

    stats_json = flatten_dict(stats_json)
    stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
    for key in stats:
        torch.testing.assert_close(stats_json[key], stats[key])

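Editorial aside (not part of the packaged file): `convert_stats_to_json` relies on `flatten_dict`/`unflatten_dict` from `opentau.datasets.utils`, which are assumed here (as in LeRobot) to join nested keys with "/". A hand-rolled sketch of that round trip:

```python
# Editorial sketch: the v1.6 stats file stores flat slash-joined keys; the converter
# nests them before dumping JSON. This reimplements the idea by hand for two keys.
import torch

flat = {
    "observation.state/mean": torch.zeros(2),
    "observation.state/std": torch.ones(2),
}
nested: dict[str, dict[str, list[float]]] = {}
for key, tensor in flat.items():
    outer, inner = key.split("/")
    nested.setdefault(outer, {})[inner] = tensor.tolist()
print(nested)
# {'observation.state': {'mean': [0.0, 0.0], 'std': [1.0, 1.0]}}
```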
def get_features_from_hf_dataset(
    dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
    """Extract feature specifications from a HuggingFace dataset.

    Converts HuggingFace dataset features to the format expected by v2.0 datasets.
    Handles Value, Sequence, Image, and VideoFrame feature types.

    Args:
        dataset: HuggingFace dataset to extract features from.
        robot_config: Optional robot configuration for motor name extraction.

    Returns:
        Dictionary mapping feature names to feature specifications with
        'dtype', 'shape', and 'names' keys.
    """
    # Guard against robot_config=None, which is a valid input per the signature.
    if robot_config is not None:
        robot_config = parse_robot_config(robot_config)
    features = {}
    for key, ft in dataset.features.items():
        if isinstance(ft, datasets.Value):
            dtype = ft.dtype
            shape = (1,)
            names = None
        if isinstance(ft, datasets.Sequence):
            assert isinstance(ft.feature, datasets.Value)
            dtype = ft.feature.dtype
            shape = (ft.length,)
            motor_names = (
                robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
            )
            assert len(motor_names) == shape[0]
            names = {"motors": motor_names}
        elif isinstance(ft, datasets.Image):
            dtype = "image"
            image = dataset[0][key]  # Assuming first row
            channels = get_image_pixel_channels(image)
            shape = (image.height, image.width, channels)
            names = ["height", "width", "channels"]
        elif ft._type == "VideoFrame":
            dtype = "video"
            shape = None  # Add shape later
            names = ["height", "width", "channels"]

        features[key] = {
            "dtype": dtype,
            "shape": shape,
            "names": names,
        }

    return features

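Editorial illustration (not part of the packaged file): the kind of feature spec `get_features_from_hf_dataset` produces for a toy HF dataset with a scalar column and a fixed-length state vector, when no robot config is given.

```python
# Editorial sketch: a toy dataset with one scalar column and one 6-dim Sequence column.
import datasets

ds = datasets.Dataset.from_dict(
    {"episode_index": [0, 0], "observation.state": [[0.0] * 6, [1.0] * 6]},
    features=datasets.Features(
        {
            "episode_index": datasets.Value("int64"),
            "observation.state": datasets.Sequence(datasets.Value("float32"), length=6),
        }
    ),
)
# get_features_from_hf_dataset(ds) would then return:
# {
#     "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
#     "observation.state": {
#         "dtype": "float32",
#         "shape": (6,),
#         "names": {"motors": ["motor_0", "motor_1", "motor_2", "motor_3", "motor_4", "motor_5"]},
#     },
# }
```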
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
    """Add task_index column to dataset based on episode-to-task mapping.

    Creates a task_index column by mapping episode indices to task indices.
    Also creates a tasks list mapping task_index to task description.

    Args:
        dataset: HuggingFace dataset to modify.
        tasks_by_episodes: Dictionary mapping episode_index to task description.

    Returns:
        Tuple of (modified_dataset, tasks_list) where tasks_list maps
        task_index to task description.
    """
    df = dataset.to_pandas()
    tasks = list(set(tasks_by_episodes.values()))
    tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
    episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
    df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)

    features = dataset.features
    features["task_index"] = datasets.Value(dtype="int64")
    dataset = Dataset.from_pandas(df, features=features, split="train")
    return dataset, tasks

def add_task_index_from_tasks_col(
    dataset: Dataset, tasks_col: str
) -> tuple[Dataset, list[str], dict]:
    """Add task_index column from an existing tasks column in the dataset.

    Extracts tasks from the specified column, creates task indices, and removes
    the original tasks column. Also handles cleaning of tensor string formats.

    Args:
        dataset: HuggingFace dataset containing a tasks column.
        tasks_col: Name of the column containing task descriptions.

    Returns:
        Tuple of (modified_dataset, tasks_list, tasks_by_episode):
        - modified_dataset: Dataset with task_index column added and tasks_col removed.
        - tasks_list: List of unique tasks.
        - tasks_by_episode: Dictionary mapping episode_index to list of tasks.
    """
    df = dataset.to_pandas()

    # HACK: This is to clean some of the instructions in our version of Open X datasets
    prefix_to_clean = "tf.Tensor(b'"
    suffix_to_clean = "', shape=(), dtype=string)"
    df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)

    # Create task_index col
    tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
    tasks = df[tasks_col].unique().tolist()
    tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
    df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)

    # Build the dataset back from df
    features = dataset.features
    features["task_index"] = datasets.Value(dtype="int64")
    dataset = Dataset.from_pandas(df, features=features, split="train")
    dataset = dataset.remove_columns(tasks_col)

    return dataset, tasks, tasks_by_episode

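Editorial illustration (not part of the packaged file): the tensor-string cleanup and per-episode grouping above, replayed on a toy pandas frame with the same column names.

```python
# Editorial sketch: strip the "tf.Tensor(b'...')" wrapper and group tasks per episode.
import pandas as pd

df = pd.DataFrame(
    {
        "episode_index": [0, 0, 1],
        "language_instruction": [
            "tf.Tensor(b'pick the cube', shape=(), dtype=string)",
            "tf.Tensor(b'pick the cube', shape=(), dtype=string)",
            "tf.Tensor(b'open the drawer', shape=(), dtype=string)",
        ],
    }
)
col = "language_instruction"
df[col] = df[col].str.removeprefix("tf.Tensor(b'").str.removesuffix("', shape=(), dtype=string)")
tasks_by_episode = df.groupby("episode_index")[col].unique().apply(list).to_dict()
print(tasks_by_episode)
# {0: ['pick the cube'], 1: ['open the drawer']}
```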
def split_parquet_by_episodes(
    dataset: Dataset,
    total_episodes: int,
    total_chunks: int,
    output_dir: Path,
) -> list:
    """Split dataset into separate parquet files, one per episode.

    Organizes episodes into chunks and writes each episode to its own parquet file
    following the v2.0 directory structure.

    Args:
        dataset: HuggingFace dataset to split.
        total_episodes: Total number of episodes in the dataset.
        total_chunks: Total number of chunks to organize episodes into.
        output_dir: Root directory where parquet files will be written.

    Returns:
        List of episode lengths (one per episode).
    """
    table = dataset.data.table
    episode_lengths = []
    for ep_chunk in range(total_chunks):
        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
        (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
        for ep_idx in range(ep_chunk_start, ep_chunk_end):
            ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
            episode_lengths.insert(ep_idx, len(ep_table))
            output_file = output_dir / DEFAULT_PARQUET_PATH.format(
                episode_chunk=ep_chunk, episode_index=ep_idx
            )
            pq.write_table(ep_table, output_file)

    return episode_lengths

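Editorial illustration (not part of the packaged file): how an episode index lands in a chunked parquet path. The chunk size (1000) and path template follow LeRobot v2.0 conventions and are assumed here; the authoritative values live in `opentau/datasets/utils.py`, which is not shown in this diff.

```python
# Editorial sketch: assumed LeRobot v2.0 defaults for chunking and parquet paths.
DEFAULT_CHUNK_SIZE = 1000
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"

ep_idx = 1042
ep_chunk = ep_idx // DEFAULT_CHUNK_SIZE  # episode 1042 falls into chunk 1
print(DEFAULT_PARQUET_PATH.format(episode_chunk=ep_chunk, episode_index=ep_idx))
# data/chunk-001/episode_001042.parquet
```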
def move_videos(
    repo_id: str,
    video_keys: list[str],
    total_episodes: int,
    total_chunks: int,
    work_dir: Path,
    clean_gittatributes: Path,
    branch: str = "main",
) -> None:
    """Move video files from v1.6 flat structure to v2.0 chunked structure.

    Uses git LFS to move video file references without downloading the actual files.
    This is a workaround since HfApi doesn't support moving files directly.

    Args:
        repo_id: Repository ID of the dataset.
        video_keys: List of video feature keys to move.
        total_episodes: Total number of episodes.
        total_chunks: Total number of chunks.
        work_dir: Working directory for git operations.
        clean_gittatributes: Path to clean .gitattributes file.
        branch: Git branch to work with. Defaults to "main".

    Note:
        This function uses git commands to manipulate LFS-tracked files without
        downloading them, which is more efficient than using the HuggingFace API.
    """
    _lfs_clone(repo_id, work_dir, branch)

    videos_moved = False
    video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
    if len(video_files) == 0:
        video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
        videos_moved = True  # Videos have already been moved

    assert len(video_files) == total_episodes * len(video_keys)

    lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)

    current_gittatributes = work_dir / ".gitattributes"
    if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
        fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)

    if lfs_untracked_videos:
        fix_lfs_video_files_tracking(work_dir, video_files)

    if videos_moved:
        return

    video_dirs = sorted(work_dir.glob("videos*/"))
    for ep_chunk in range(total_chunks):
        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
        for vid_key in video_keys:
            chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
                episode_chunk=ep_chunk, video_key=vid_key
            )
            (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)

            for ep_idx in range(ep_chunk_start, ep_chunk_end):
                target_path = DEFAULT_VIDEO_PATH.format(
                    episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
                )
                video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
                if len(video_dirs) == 1:
                    video_path = video_dirs[0] / video_file
                else:
                    for dir in video_dirs:
                        if (dir / video_file).is_file():
                            video_path = dir / video_file
                            break

                video_path.rename(work_dir / target_path)

    commit_message = "Move video files into chunk subdirectories"
    subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
    subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
    subprocess.run(["git", "push"], cwd=work_dir, check=True)

def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
    """Fix Git LFS tracking for video files that aren't properly tracked.

    Adds video files to .gitattributes to ensure they're tracked by Git LFS.

    HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
    there's no other option than to download the actual files and reupload them with lfs tracking.

    Args:
        work_dir: Working directory containing the repository.
        lfs_untracked_videos: List of video file paths that need LFS tracking.
    """
    for i in range(0, len(lfs_untracked_videos), 100):
        files = lfs_untracked_videos[i : i + 100]
        try:
            subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
        except subprocess.CalledProcessError as e:
            print("git rm --cached ERROR:")
            print(e.stderr)
        subprocess.run(["git", "add", *files], cwd=work_dir, check=True)

    commit_message = "Track video files with git lfs"
    subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
    subprocess.run(["git", "push"], cwd=work_dir, check=True)

def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
    """Replace .gitattributes file with a clean version and commit the change.

    Args:
        work_dir: Working directory containing the repository.
        current_gittatributes: Path to current .gitattributes file.
        clean_gittatributes: Path to clean .gitattributes file to use as replacement.
    """
    shutil.copyfile(clean_gittatributes, current_gittatributes)
    subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
    subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
    subprocess.run(["git", "push"], cwd=work_dir, check=True)


def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
    """Clone a repository with Git LFS support, skipping actual file downloads.

    Initializes Git LFS and clones the repository without downloading LFS files
    (using GIT_LFS_SKIP_SMUDGE=1) for efficiency.

    Args:
        repo_id: Repository ID to clone.
        work_dir: Directory where the repository will be cloned.
        branch: Git branch to checkout.
    """
    subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
    repo_url = f"https://huggingface.co/datasets/{repo_id}"
    env = {"GIT_LFS_SKIP_SMUDGE": "1"}  # Prevent downloading LFS files
    subprocess.run(
        ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
        check=True,
        env=env,
    )


def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
    """Get list of video files that are not tracked by Git LFS.

    Args:
        work_dir: Working directory containing the repository.
        video_files: List of video file paths to check.

    Returns:
        List of video file paths that are not tracked by Git LFS.
    """
    lfs_tracked_files = subprocess.run(
        ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
    )
    lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
    return [f for f in video_files if f not in lfs_tracked_files]

def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
    """Get video information (codec, dimensions, etc.) from the first episode.

    Downloads video files from the first episode to extract metadata.

    Args:
        repo_id: Repository ID of the dataset.
        local_dir: Local directory where videos will be downloaded.
        video_keys: List of video feature keys to get info for.
        branch: Git branch to use.

    Returns:
        Dictionary mapping video keys to their information dictionaries.
    """
    # Assumes first episode
    video_files = [
        DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
        for vid_key in video_keys
    ]
    hub_api = HfApi()
    hub_api.snapshot_download(
        repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
    )
    videos_info_dict = {}
    for vid_key, vid_path in zip(video_keys, video_files, strict=True):
        videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)

    return videos_info_dict

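Editorial illustration (not part of the packaged file): the shape of the per-key metadata that `convert_dataset` below reads from `get_videos_info`. The keys are inferred from how they are consumed ("video.height", "video.width", "video.channels", "video.fps", "video.pix_fmt"); any extra fields depend on `opentau.datasets.video_utils.get_video_info`, and the sample values here are made up.

```python
# Editorial sketch: an example videos_info mapping with assumed sample values.
example_videos_info = {
    "observation.images.top": {
        "video.height": 480,
        "video.width": 640,
        "video.channels": 3,
        "video.fps": 30,
        "video.pix_fmt": "yuv420p",
    }
}
```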
def convert_dataset(
    repo_id: str,
    local_dir: Path,
    single_task: str | None = None,
    tasks_path: Path | None = None,
    tasks_col: str | None = None,
    robot_config: RobotConfig | None = None,
    test_branch: str | None = None,
    **card_kwargs,
) -> None:
    """Convert a dataset from v1.6 format to v2.0 format.

    Handles conversion of metadata, statistics, parquet files, videos, and tasks.
    Supports three scenarios: single task dataset, single task per episode, or
    multiple tasks per episode.

    Args:
        repo_id: Repository ID of the dataset to convert.
        local_dir: Local directory for downloading and writing converted dataset.
        single_task: Single task description for datasets with one task.
            Mutually exclusive with tasks_path and tasks_col.
        tasks_path: Path to JSON file mapping episode_index to task descriptions.
            Mutually exclusive with single_task and tasks_col.
        tasks_col: Name of column in parquet files containing task descriptions.
            Mutually exclusive with single_task and tasks_path.
        robot_config: Optional robot configuration for extracting motor names.
        test_branch: Optional branch name for testing conversion without affecting main.
        **card_kwargs: Additional keyword arguments for dataset card creation.

    Raises:
        ValueError: If task specification arguments are invalid or missing.
    """
    v1 = get_safe_version(repo_id, V16)
    v1x_dir = local_dir / V16 / repo_id
    v20_dir = local_dir / V20 / repo_id
    v1x_dir.mkdir(parents=True, exist_ok=True)
    v20_dir.mkdir(parents=True, exist_ok=True)

    hub_api = HfApi()
    hub_api.snapshot_download(
        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
    )
    branch = "main"
    if test_branch:
        branch = test_branch
        create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")

    metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
    dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
    features = get_features_from_hf_dataset(dataset, robot_config)
    video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]

    if single_task and "language_instruction" in dataset.column_names:
        logging.warning(
            "'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
        )
        single_task = None
        tasks_col = "language_instruction"

    # Episodes & chunks
    episode_indices = sorted(dataset.unique("episode_index"))
    total_episodes = len(episode_indices)
    assert episode_indices == list(range(total_episodes))
    total_videos = total_episodes * len(video_keys)
    total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
    if total_episodes % DEFAULT_CHUNK_SIZE != 0:
        total_chunks += 1

    # Tasks
    if single_task:
        tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
    elif tasks_path:
        tasks_by_episodes = load_json(tasks_path)
        tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
    elif tasks_col:
        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
    else:
        raise ValueError

    assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
    tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
    write_jsonlines(tasks, v20_dir / TASKS_PATH)
    features["task_index"] = {
        "dtype": "int64",
        "shape": (1,),
        "names": None,
    }

    # Videos
    if video_keys:
        assert metadata_v1.get("video", False)
        dataset = dataset.remove_columns(video_keys)
        clean_gitattr = Path(
            hub_api.hf_hub_download(
                repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
            )
        ).absolute()
        with tempfile.TemporaryDirectory() as tmp_video_dir:
            move_videos(
                repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
            )
        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
        for key in video_keys:
            features[key]["shape"] = (
                videos_info[key].pop("video.height"),
                videos_info[key].pop("video.width"),
                videos_info[key].pop("video.channels"),
            )
            features[key]["video_info"] = videos_info[key]
            assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
            if "encoding" in metadata_v1:
                assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
    else:
        assert metadata_v1.get("video", 0) == 0
        videos_info = None

    # Split data into 1 parquet file by episode
    episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)

    if robot_config is not None:
        robot_type = robot_config.type
        repo_tags = [robot_type]
    else:
        robot_type = "unknown"
        repo_tags = None

    # Episodes
    episodes = [
        {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
        for ep_idx in episode_indices
    ]
    write_jsonlines(episodes, v20_dir / EPISODES_PATH)

    # Assemble metadata v2.0
    metadata_v2_0 = {
        "codebase_version": V20,
        "robot_type": robot_type,
        "total_episodes": total_episodes,
        "total_frames": len(dataset),
        "total_tasks": len(tasks),
        "total_videos": total_videos,
        "total_chunks": total_chunks,
        "chunks_size": DEFAULT_CHUNK_SIZE,
        "fps": metadata_v1["fps"],
        "splits": {"train": f"0:{total_episodes}"},
        "data_path": DEFAULT_PARQUET_PATH,
        "video_path": DEFAULT_VIDEO_PATH if video_keys else None,
        "features": features,
    }
    write_json(metadata_v2_0, v20_dir / INFO_PATH)
    convert_stats_to_json(v1x_dir, v20_dir)
    card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)

    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)

    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)

    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)

    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="data",
        folder_path=v20_dir / "data",
        repo_type="dataset",
        revision=branch,
    )
    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="meta",
        folder_path=v20_dir / "meta",
        repo_type="dataset",
        revision=branch,
    )

    card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)

    if not test_branch:
        create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")

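Editorial illustration (not part of the packaged file): invoking the converter from Python rather than the CLI. The repo id and task string are examples taken from the module docstring; exactly one of `single_task`, `tasks_path`, or `tasks_col` should be set, and the import path is assumed from the package listing above.

```python
# Editorial sketch: programmatic use of convert_dataset, mirroring the CLI examples
# in the module docstring.
from pathlib import Path

from opentau.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset

convert_dataset(
    repo_id="lerobot/aloha_sim_insertion_human_image",  # example repo from the docstring
    local_dir=Path("/tmp/lerobot_dataset_v2"),
    single_task="Insert the peg into the socket.",
    test_branch="v2.0.test",  # optional: push the converted layout to a side branch first
)
```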
def main():
    parser = argparse.ArgumentParser()
    task_args = parser.add_mutually_exclusive_group(required=True)

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    task_args.add_argument(
        "--single-task",
        type=str,
        help="A short but accurate description of the single task performed in the dataset.",
    )
    task_args.add_argument(
        "--tasks-col",
        type=str,
        help="The name of the column containing language instructions",
    )
    task_args.add_argument(
        "--tasks-path",
        type=Path,
        help="The path to a .json file containing one language instruction for each episode_index",
    )
    parser.add_argument(
        "--robot",
        type=str,
        default=None,
        help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
    )
    parser.add_argument(
        "--local-dir",
        type=Path,
        default=None,
        help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
    )
    parser.add_argument(
        "--license",
        type=str,
        default="apache-2.0",
        help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to apache-2.0.",
    )
    parser.add_argument(
        "--test-branch",
        type=str,
        default=None,
        help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
    )

    args = parser.parse_args()
    if not args.local_dir:
        args.local_dir = Path("/tmp/lerobot_dataset_v2")

    robot_config = None
    if args.robot is not None:
        robot_config = make_robot_config(args.robot)

    del args.robot

    convert_dataset(**vars(args), robot_config=robot_config)


if __name__ == "__main__":
    main()