opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
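The listing above can be reproduced locally from the wheel itself, since a wheel is just a zip archive. Below is a minimal sketch using only the Python standard library; the filename and download location are assumptions, so point `wheel_path` at wherever you fetched the file.

```python
import zipfile

# Assumed local path to the downloaded wheel; adjust as needed.
wheel_path = "opentau-0.1.0-py3-none-any.whl"

# A wheel is a plain zip archive: list every member with its uncompressed
# size to cross-check the "Files changed" table above.
with zipfile.ZipFile(wheel_path) as whl:
    for info in sorted(whl.infolist(), key=lambda i: i.filename):
        print(f"{info.file_size:>10}  {info.filename}")
```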
opentau/datasets/v2/convert_dataset_v1_to_v2.py
@@ -0,0 +1,829 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
20
+ 2.0. You will be required to provide the 'tasks': a short but accurate description in plain English of
21
+ each task performed in the dataset. This will allow you to easily train models with task-conditioning.
22
+
23
+ We support 3 different scenarios for these tasks (see instructions below):
24
+ 1. Single task dataset: all episodes of your dataset have the same single task.
25
+ 2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
26
+ one episode to the next.
27
+ 3. Multi task episodes: episodes of your dataset may each contain several different tasks.
28
+
29
+
30
+ You can also provide a robot config name (not mandatory) to this script via the option
31
+ '--robot' so that it writes information about the robot (robot type, motor names) this dataset was
32
+ recorded with. For now, only Aloha/Koch type robots are supported with this option.
33
+
34
+
35
+ # 1. Single task dataset
36
+ If your dataset contains a single task, you can simply provide it directly via the CLI with the
37
+ '--single-task' option.
38
+
39
+ Examples:
40
+
41
+ ```bash
42
+ python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \
43
+ --repo-id lerobot/aloha_sim_insertion_human_image \
44
+ --single-task "Insert the peg into the socket." \
45
+ --robot aloha \
46
+ --local-dir data
47
+ ```
48
+
49
+ ```bash
50
+ python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \
51
+ --repo-id aliberts/koch_tutorial \
52
+ --single-task "Pick the Lego block and drop it in the box on the right." \
53
+ --robot koch \
54
+ --local-dir data
55
+ ```
56
+
57
+
58
+ # 2. Single task episodes
59
+ If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
60
+
61
+ - If your dataset already contains a language instruction column in its parquet file, you can simply provide
62
+ this column's name with the '--tasks-col' arg.
63
+
64
+ Example:
65
+
66
+ ```bash
67
+ python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \
68
+ --repo-id lerobot/stanford_kuka_multimodal_dataset \
69
+ --tasks-col "language_instruction" \
70
+ --local-dir data
71
+ ```
72
+
73
+ - If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
74
+ '--tasks-path' arg. This file should have the following structure where keys correspond to each
75
+ episode_index in the dataset, and values are the language instruction for that episode.
76
+
77
+ Example:
78
+
79
+ ```json
80
+ {
81
+ "0": "Do something",
82
+ "1": "Do something else",
83
+ "2": "Do something",
84
+ "3": "Go there",
85
+ ...
86
+ }
87
+ ```
88
+
89
+ # 3. Multi task episodes
90
+ If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
91
+ parquet file, and you must provide this column's name with the '--tasks-col' arg.
92
+
93
+ Example:
94
+
95
+ ```bash
96
+ python src/opentau/datasets/v2/convert_dataset_v1_to_v2.py \
97
+ --repo-id lerobot/stanford_kuka_multimodal_dataset \
98
+ --tasks-col "language_instruction" \
99
+ --local-dir data
100
+ ```
101
+ """
102
+
103
+ import argparse
104
+ import contextlib
105
+ import filecmp
106
+ import json
107
+ import logging
108
+ import math
109
+ import shutil
110
+ import subprocess
111
+ import tempfile
112
+ from pathlib import Path
113
+
114
+ import datasets
115
+ import pyarrow.compute as pc
116
+ import pyarrow.parquet as pq
117
+ import torch
118
+ from datasets import Dataset
119
+ from huggingface_hub import HfApi
120
+ from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
121
+ from safetensors.torch import load_file
122
+
123
+ from opentau.datasets.utils import (
124
+ DEFAULT_CHUNK_SIZE,
125
+ DEFAULT_PARQUET_PATH,
126
+ DEFAULT_VIDEO_PATH,
127
+ EPISODES_PATH,
128
+ INFO_PATH,
129
+ STATS_PATH,
130
+ TASKS_PATH,
131
+ create_branch,
132
+ create_lerobot_dataset_card,
133
+ flatten_dict,
134
+ get_safe_version,
135
+ load_json,
136
+ unflatten_dict,
137
+ write_json,
138
+ write_jsonlines,
139
+ )
140
+ from opentau.datasets.video_utils import (
141
+ VideoFrame, # noqa: F401
142
+ get_image_pixel_channels,
143
+ get_video_info,
144
+ )
145
+ from opentau.robot_devices.robots.configs import RobotConfig
146
+ from opentau.robot_devices.robots.utils import make_robot_config
147
+
148
+ V16 = "v1.6"
149
+ V20 = "v2.0"
150
+
151
+ GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
152
+ V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
153
+ V1_INFO_PATH = "meta_data/info.json"
154
+ V1_STATS_PATH = "meta_data/stats.safetensors"
155
+
156
+
157
+ def parse_robot_config(robot_cfg: RobotConfig) -> dict:
158
+ """Parse robot configuration to extract robot type and motor names.
159
+
160
+ Extracts state and action motor names from the robot configuration.
161
+ Currently supports "aloha" and "koch" robot types.
162
+
163
+ Args:
164
+ robot_cfg: Robot configuration object.
165
+
166
+ Returns:
167
+ Dictionary with 'robot_type' and 'names' keys. The 'names' dictionary
168
+ maps feature keys to lists of motor names.
169
+
170
+ Raises:
171
+ NotImplementedError: If robot type is not supported.
172
+ """
173
+ if robot_cfg.type in ["aloha", "koch"]:
174
+ state_names = [
175
+ f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
176
+ for arm in robot_cfg.follower_arms
177
+ for motor in robot_cfg.follower_arms[arm].motors
178
+ ]
179
+ action_names = [
180
+ # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
181
+ f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
182
+ for arm in robot_cfg.leader_arms
183
+ for motor in robot_cfg.leader_arms[arm].motors
184
+ ]
185
+ # elif robot_cfg["robot_type"] == "stretch3": TODO
186
+ else:
187
+ raise NotImplementedError(
188
+ "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
189
+ )
190
+
191
+ return {
192
+ "robot_type": robot_cfg.type,
193
+ "names": {
194
+ "observation.state": state_names,
195
+ "observation.effort": state_names,
196
+ "action": action_names,
197
+ },
198
+ }
199
+
200
+
201
+ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
202
+ """Convert statistics from v1.6 safetensors format to v2.0 JSON format.
203
+
204
+ Loads statistics from safetensors file, converts to JSON, and validates
205
+ that the conversion preserves all values correctly.
206
+
207
+ Args:
208
+ v1_dir: Directory containing v1.6 dataset with meta_data/stats.safetensors.
209
+ v2_dir: Directory where v2.0 dataset meta/stats.json will be written.
210
+
211
+ Raises:
212
+ AssertionError: If converted statistics don't match original values.
213
+ """
214
+ safetensor_path = v1_dir / V1_STATS_PATH
215
+ stats = load_file(safetensor_path)
216
+ serialized_stats = {key: value.tolist() for key, value in stats.items()}
217
+ serialized_stats = unflatten_dict(serialized_stats)
218
+
219
+ json_path = v2_dir / STATS_PATH
220
+ json_path.parent.mkdir(exist_ok=True, parents=True)
221
+ with open(json_path, "w") as f:
222
+ json.dump(serialized_stats, f, indent=4)
223
+
224
+ # Sanity check
225
+ with open(json_path) as f:
226
+ stats_json = json.load(f)
227
+
228
+ stats_json = flatten_dict(stats_json)
229
+ stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
230
+ for key in stats:
231
+ torch.testing.assert_close(stats_json[key], stats[key])
232
+
233
+
234
+ def get_features_from_hf_dataset(
235
+ dataset: Dataset, robot_config: RobotConfig | None = None
236
+ ) -> dict[str, dict]:
237
+ """Extract feature specifications from a HuggingFace dataset.
238
+
239
+ Converts HuggingFace dataset features to the format expected by v2.0 datasets.
240
+ Handles Value, Sequence, Image, and VideoFrame feature types.
241
+
242
+ Args:
243
+ dataset: HuggingFace dataset to extract features from.
244
+ robot_config: Optional robot configuration for motor name extraction.
245
+
246
+ Returns:
247
+ Dictionary mapping feature names to feature specifications with
248
+ 'dtype', 'shape', and 'names' keys.
249
+ """
250
+ robot_config = parse_robot_config(robot_config) if robot_config else None
251
+ features = {}
252
+ for key, ft in dataset.features.items():
253
+ if isinstance(ft, datasets.Value):
254
+ dtype = ft.dtype
255
+ shape = (1,)
256
+ names = None
257
+ if isinstance(ft, datasets.Sequence):
258
+ assert isinstance(ft.feature, datasets.Value)
259
+ dtype = ft.feature.dtype
260
+ shape = (ft.length,)
261
+ motor_names = (
262
+ robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
263
+ )
264
+ assert len(motor_names) == shape[0]
265
+ names = {"motors": motor_names}
266
+ elif isinstance(ft, datasets.Image):
267
+ dtype = "image"
268
+ image = dataset[0][key] # Assuming first row
269
+ channels = get_image_pixel_channels(image)
270
+ shape = (image.height, image.width, channels)
271
+ names = ["height", "width", "channels"]
272
+ elif ft._type == "VideoFrame":
273
+ dtype = "video"
274
+ shape = None # Add shape later
275
+ names = ["height", "width", "channels"]
276
+
277
+ features[key] = {
278
+ "dtype": dtype,
279
+ "shape": shape,
280
+ "names": names,
281
+ }
282
+
283
+ return features
284
+
285
+
286
+ def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
287
+ """Add task_index column to dataset based on episode-to-task mapping.
288
+
289
+ Creates a task_index column by mapping episode indices to task indices.
290
+ Also creates a tasks list mapping task_index to task description.
291
+
292
+ Args:
293
+ dataset: HuggingFace dataset to modify.
294
+ tasks_by_episodes: Dictionary mapping episode_index to task description.
295
+
296
+ Returns:
297
+ Tuple of (modified_dataset, tasks_list) where tasks_list maps
298
+ task_index to task description.
299
+ """
300
+ df = dataset.to_pandas()
301
+ tasks = list(set(tasks_by_episodes.values()))
302
+ tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
303
+ episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
304
+ df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
305
+
306
+ features = dataset.features
307
+ features["task_index"] = datasets.Value(dtype="int64")
308
+ dataset = Dataset.from_pandas(df, features=features, split="train")
309
+ return dataset, tasks
310
+
311
+
312
+ def add_task_index_from_tasks_col(
313
+ dataset: Dataset, tasks_col: str
314
+ ) -> tuple[Dataset, list[str], dict[int, list[str]]]:
315
+ """Add task_index column from an existing tasks column in the dataset.
316
+
317
+ Extracts tasks from the specified column, creates task indices, and removes
318
+ the original tasks column. Also handles cleaning of tensor string formats.
319
+
320
+ Args:
321
+ dataset: HuggingFace dataset containing a tasks column.
322
+ tasks_col: Name of the column containing task descriptions.
323
+
324
+ Returns:
325
+ Tuple of (modified_dataset, tasks_list, tasks_by_episode):
326
+ - modified_dataset: Dataset with task_index column added and tasks_col removed.
327
+ - tasks_list: List of unique tasks.
328
+ - tasks_by_episode: Dictionary mapping episode_index to list of tasks.
329
+ """
330
+ df = dataset.to_pandas()
331
+
332
+ # HACK: This is to clean some of the instructions in our version of Open X datasets
333
+ prefix_to_clean = "tf.Tensor(b'"
334
+ suffix_to_clean = "', shape=(), dtype=string)"
335
+ df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
336
+
337
+ # Create task_index col
338
+ tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
339
+ tasks = df[tasks_col].unique().tolist()
340
+ tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
341
+ df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
342
+
343
+ # Build the dataset back from df
344
+ features = dataset.features
345
+ features["task_index"] = datasets.Value(dtype="int64")
346
+ dataset = Dataset.from_pandas(df, features=features, split="train")
347
+ dataset = dataset.remove_columns(tasks_col)
348
+
349
+ return dataset, tasks, tasks_by_episode
350
+
351
+
352
+ def split_parquet_by_episodes(
353
+ dataset: Dataset,
354
+ total_episodes: int,
355
+ total_chunks: int,
356
+ output_dir: Path,
357
+ ) -> list:
358
+ """Split dataset into separate parquet files, one per episode.
359
+
360
+ Organizes episodes into chunks and writes each episode to its own parquet file
361
+ following the v2.0 directory structure.
362
+
363
+ Args:
364
+ dataset: HuggingFace dataset to split.
365
+ total_episodes: Total number of episodes in the dataset.
366
+ total_chunks: Total number of chunks to organize episodes into.
367
+ output_dir: Root directory where parquet files will be written.
368
+
369
+ Returns:
370
+ List of episode lengths (one per episode).
371
+ """
372
+ table = dataset.data.table
373
+ episode_lengths = []
374
+ for ep_chunk in range(total_chunks):
375
+ ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
376
+ ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
377
+ chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
378
+ (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
379
+ for ep_idx in range(ep_chunk_start, ep_chunk_end):
380
+ ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
381
+ episode_lengths.insert(ep_idx, len(ep_table))
382
+ output_file = output_dir / DEFAULT_PARQUET_PATH.format(
383
+ episode_chunk=ep_chunk, episode_index=ep_idx
384
+ )
385
+ pq.write_table(ep_table, output_file)
386
+
387
+ return episode_lengths
388
+
389
+
390
+ def move_videos(
391
+ repo_id: str,
392
+ video_keys: list[str],
393
+ total_episodes: int,
394
+ total_chunks: int,
395
+ work_dir: Path,
396
+ clean_gittatributes: Path,
397
+ branch: str = "main",
398
+ ) -> None:
399
+ """Move video files from v1.6 flat structure to v2.0 chunked structure.
400
+
401
+ Uses git LFS to move video file references without downloading the actual files.
402
+ This is a workaround since HfApi doesn't support moving files directly.
403
+
404
+ Args:
405
+ repo_id: Repository ID of the dataset.
406
+ video_keys: List of video feature keys to move.
407
+ total_episodes: Total number of episodes.
408
+ total_chunks: Total number of chunks.
409
+ work_dir: Working directory for git operations.
410
+ clean_gittatributes: Path to clean .gitattributes file.
411
+ branch: Git branch to work with. Defaults to "main".
412
+
413
+ Note:
414
+ This function uses git commands to manipulate LFS-tracked files without
415
+ downloading them, which is more efficient than using the HuggingFace API.
416
+ """
417
+ _lfs_clone(repo_id, work_dir, branch)
418
+
419
+ videos_moved = False
420
+ video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
421
+ if len(video_files) == 0:
422
+ video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
423
+ videos_moved = True # Videos have already been moved
424
+
425
+ assert len(video_files) == total_episodes * len(video_keys)
426
+
427
+ lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
428
+
429
+ current_gittatributes = work_dir / ".gitattributes"
430
+ if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
431
+ fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
432
+
433
+ if lfs_untracked_videos:
434
+ fix_lfs_video_files_tracking(work_dir, video_files)
435
+
436
+ if videos_moved:
437
+ return
438
+
439
+ video_dirs = sorted(work_dir.glob("videos*/"))
440
+ for ep_chunk in range(total_chunks):
441
+ ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
442
+ ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
443
+ for vid_key in video_keys:
444
+ chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
445
+ episode_chunk=ep_chunk, video_key=vid_key
446
+ )
447
+ (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
448
+
449
+ for ep_idx in range(ep_chunk_start, ep_chunk_end):
450
+ target_path = DEFAULT_VIDEO_PATH.format(
451
+ episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
452
+ )
453
+ video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
454
+ if len(video_dirs) == 1:
455
+ video_path = video_dirs[0] / video_file
456
+ else:
457
+ for dir in video_dirs:
458
+ if (dir / video_file).is_file():
459
+ video_path = dir / video_file
460
+ break
461
+
462
+ video_path.rename(work_dir / target_path)
463
+
464
+ commit_message = "Move video files into chunk subdirectories"
465
+ subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
466
+ subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
467
+ subprocess.run(["git", "push"], cwd=work_dir, check=True)
468
+
469
+
470
+ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
471
+ """Fix Git LFS tracking for video files that aren't properly tracked.
472
+
473
+ Re-stages the given video files (git rm --cached, then git add) in batches so that Git LFS picks them up, then commits and pushes the result.
474
+
475
+ Args:
476
+ work_dir: Working directory containing the repository.
477
+ lfs_untracked_videos: List of video file paths that need LFS tracking.
478
+ """
479
+ """
480
+ HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
481
+ there's no other option than to download the actual files and reupload them with lfs tracking.
482
+ """
483
+ for i in range(0, len(lfs_untracked_videos), 100):
484
+ files = lfs_untracked_videos[i : i + 100]
485
+ try:
486
+ subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
487
+ except subprocess.CalledProcessError as e:
488
+ print("git rm --cached ERROR:")
489
+ print(e.stderr)
490
+ subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
491
+
492
+ commit_message = "Track video files with git lfs"
493
+ subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
494
+ subprocess.run(["git", "push"], cwd=work_dir, check=True)
495
+
496
+
497
+ def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
498
+ """Replace .gitattributes file with a clean version and commit the change.
499
+
500
+ Args:
501
+ work_dir: Working directory containing the repository.
502
+ current_gittatributes: Path to current .gitattributes file.
503
+ clean_gittatributes: Path to clean .gitattributes file to use as replacement.
504
+ """
505
+ shutil.copyfile(clean_gittatributes, current_gittatributes)
506
+ subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
507
+ subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
508
+ subprocess.run(["git", "push"], cwd=work_dir, check=True)
509
+
510
+
511
+ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
512
+ """Clone a repository with Git LFS support, skipping actual file downloads.
513
+
514
+ Initializes Git LFS and clones the repository without downloading LFS files
515
+ (using GIT_LFS_SKIP_SMUDGE=1) for efficiency.
516
+
517
+ Args:
518
+ repo_id: Repository ID to clone.
519
+ work_dir: Directory where the repository will be cloned.
520
+ branch: Git branch to checkout.
521
+ """
522
+ subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
523
+ repo_url = f"https://huggingface.co/datasets/{repo_id}"
524
+ env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
525
+ subprocess.run(
526
+ ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
527
+ check=True,
528
+ env=env,
529
+ )
530
+
531
+
532
+ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
533
+ """Get list of video files that are not tracked by Git LFS.
534
+
535
+ Args:
536
+ work_dir: Working directory containing the repository.
537
+ video_files: List of video file paths to check.
538
+
539
+ Returns:
540
+ List of video file paths that are not tracked by Git LFS.
541
+ """
542
+ lfs_tracked_files = subprocess.run(
543
+ ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
544
+ )
545
+ lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
546
+ return [f for f in video_files if f not in lfs_tracked_files]
547
+
548
+
549
+ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
550
+ """Get video information (codec, dimensions, etc.) from the first episode.
551
+
552
+ Downloads video files from the first episode to extract metadata.
553
+
554
+ Args:
555
+ repo_id: Repository ID of the dataset.
556
+ local_dir: Local directory where videos will be downloaded.
557
+ video_keys: List of video feature keys to get info for.
558
+ branch: Git branch to use.
559
+
560
+ Returns:
561
+ Dictionary mapping video keys to their information dictionaries.
562
+ """
563
+ # Assumes first episode
564
+ video_files = [
565
+ DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
566
+ for vid_key in video_keys
567
+ ]
568
+ hub_api = HfApi()
569
+ hub_api.snapshot_download(
570
+ repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
571
+ )
572
+ videos_info_dict = {}
573
+ for vid_key, vid_path in zip(video_keys, video_files, strict=True):
574
+ videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
575
+
576
+ return videos_info_dict
577
+
578
+
579
+ def convert_dataset(
580
+ repo_id: str,
581
+ local_dir: Path,
582
+ single_task: str | None = None,
583
+ tasks_path: Path | None = None,
584
+ tasks_col: str | None = None,
585
+ robot_config: RobotConfig | None = None,
586
+ test_branch: str | None = None,
587
+ **card_kwargs,
588
+ ) -> None:
589
+ """Convert a dataset from v1.6 format to v2.0 format.
590
+
591
+ Handles conversion of metadata, statistics, parquet files, videos, and tasks.
592
+ Supports three scenarios: single task dataset, single task per episode, or
593
+ multiple tasks per episode.
594
+
595
+ Args:
596
+ repo_id: Repository ID of the dataset to convert.
597
+ local_dir: Local directory for downloading and writing converted dataset.
598
+ single_task: Single task description for datasets with one task.
599
+ Mutually exclusive with tasks_path and tasks_col.
600
+ tasks_path: Path to JSON file mapping episode_index to task descriptions.
601
+ Mutually exclusive with single_task and tasks_col.
602
+ tasks_col: Name of column in parquet files containing task descriptions.
603
+ Mutually exclusive with single_task and tasks_path.
604
+ robot_config: Optional robot configuration for extracting motor names.
605
+ test_branch: Optional branch name for testing conversion without affecting main.
606
+ **card_kwargs: Additional keyword arguments for dataset card creation.
607
+
608
+ Raises:
609
+ ValueError: If task specification arguments are invalid or missing.
610
+ """
611
+ v1 = get_safe_version(repo_id, V16)
612
+ v1x_dir = local_dir / V16 / repo_id
613
+ v20_dir = local_dir / V20 / repo_id
614
+ v1x_dir.mkdir(parents=True, exist_ok=True)
615
+ v20_dir.mkdir(parents=True, exist_ok=True)
616
+
617
+ hub_api = HfApi()
618
+ hub_api.snapshot_download(
619
+ repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
620
+ )
621
+ branch = "main"
622
+ if test_branch:
623
+ branch = test_branch
624
+ create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
625
+
626
+ metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
627
+ dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
628
+ features = get_features_from_hf_dataset(dataset, robot_config)
629
+ video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
630
+
631
+ if single_task and "language_instruction" in dataset.column_names:
632
+ logging.warning(
633
+ "'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
634
+ )
635
+ single_task = None
636
+ tasks_col = "language_instruction"
637
+
638
+ # Episodes & chunks
639
+ episode_indices = sorted(dataset.unique("episode_index"))
640
+ total_episodes = len(episode_indices)
641
+ assert episode_indices == list(range(total_episodes))
642
+ total_videos = total_episodes * len(video_keys)
643
+ total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
644
+ if total_episodes % DEFAULT_CHUNK_SIZE != 0:
645
+ total_chunks += 1
646
+
647
+ # Tasks
648
+ if single_task:
649
+ tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
650
+ dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
651
+ tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
652
+ elif tasks_path:
653
+ tasks_by_episodes = load_json(tasks_path)
654
+ tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
655
+ dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
656
+ tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
657
+ elif tasks_col:
658
+ dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
659
+ else:
660
+ raise ValueError("One of 'single_task', 'tasks_path' or 'tasks_col' must be provided.")
661
+
662
+ assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
663
+ tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
664
+ write_jsonlines(tasks, v20_dir / TASKS_PATH)
665
+ features["task_index"] = {
666
+ "dtype": "int64",
667
+ "shape": (1,),
668
+ "names": None,
669
+ }
670
+
671
+ # Videos
672
+ if video_keys:
673
+ assert metadata_v1.get("video", False)
674
+ dataset = dataset.remove_columns(video_keys)
675
+ clean_gitattr = Path(
676
+ hub_api.hf_hub_download(
677
+ repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
678
+ )
679
+ ).absolute()
680
+ with tempfile.TemporaryDirectory() as tmp_video_dir:
681
+ move_videos(
682
+ repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
683
+ )
684
+ videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
685
+ for key in video_keys:
686
+ features[key]["shape"] = (
687
+ videos_info[key].pop("video.height"),
688
+ videos_info[key].pop("video.width"),
689
+ videos_info[key].pop("video.channels"),
690
+ )
691
+ features[key]["video_info"] = videos_info[key]
692
+ assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
693
+ if "encoding" in metadata_v1:
694
+ assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
695
+ else:
696
+ assert metadata_v1.get("video", 0) == 0
697
+ videos_info = None
698
+
699
+ # Split data into 1 parquet file by episode
700
+ episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
701
+
702
+ if robot_config is not None:
703
+ robot_type = robot_config.type
704
+ repo_tags = [robot_type]
705
+ else:
706
+ robot_type = "unknown"
707
+ repo_tags = None
708
+
709
+ # Episodes
710
+ episodes = [
711
+ {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
712
+ for ep_idx in episode_indices
713
+ ]
714
+ write_jsonlines(episodes, v20_dir / EPISODES_PATH)
715
+
716
+ # Assemble metadata v2.0
717
+ metadata_v2_0 = {
718
+ "codebase_version": V20,
719
+ "robot_type": robot_type,
720
+ "total_episodes": total_episodes,
721
+ "total_frames": len(dataset),
722
+ "total_tasks": len(tasks),
723
+ "total_videos": total_videos,
724
+ "total_chunks": total_chunks,
725
+ "chunks_size": DEFAULT_CHUNK_SIZE,
726
+ "fps": metadata_v1["fps"],
727
+ "splits": {"train": f"0:{total_episodes}"},
728
+ "data_path": DEFAULT_PARQUET_PATH,
729
+ "video_path": DEFAULT_VIDEO_PATH if video_keys else None,
730
+ "features": features,
731
+ }
732
+ write_json(metadata_v2_0, v20_dir / INFO_PATH)
733
+ convert_stats_to_json(v1x_dir, v20_dir)
734
+ card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
735
+
736
+ with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
737
+ hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
738
+
739
+ with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
740
+ hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
741
+
742
+ with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
743
+ hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
744
+
745
+ hub_api.upload_folder(
746
+ repo_id=repo_id,
747
+ path_in_repo="data",
748
+ folder_path=v20_dir / "data",
749
+ repo_type="dataset",
750
+ revision=branch,
751
+ )
752
+ hub_api.upload_folder(
753
+ repo_id=repo_id,
754
+ path_in_repo="meta",
755
+ folder_path=v20_dir / "meta",
756
+ repo_type="dataset",
757
+ revision=branch,
758
+ )
759
+
760
+ card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
761
+
762
+ if not test_branch:
763
+ create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
764
+
765
+
766
+ def main():
767
+ parser = argparse.ArgumentParser()
768
+ task_args = parser.add_mutually_exclusive_group(required=True)
769
+
770
+ parser.add_argument(
771
+ "--repo-id",
772
+ type=str,
773
+ required=True,
774
+ help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
775
+ )
776
+ task_args.add_argument(
777
+ "--single-task",
778
+ type=str,
779
+ help="A short but accurate description of the single task performed in the dataset.",
780
+ )
781
+ task_args.add_argument(
782
+ "--tasks-col",
783
+ type=str,
784
+ help="The name of the column containing language instructions",
785
+ )
786
+ task_args.add_argument(
787
+ "--tasks-path",
788
+ type=Path,
789
+ help="The path to a .json file containing one language instruction for each episode_index",
790
+ )
791
+ parser.add_argument(
792
+ "--robot",
793
+ type=str,
794
+ default=None,
795
+ help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
796
+ )
797
+ parser.add_argument(
798
+ "--local-dir",
799
+ type=Path,
800
+ default=None,
801
+ help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
802
+ )
803
+ parser.add_argument(
804
+ "--license",
805
+ type=str,
806
+ default="apache-2.0",
807
+ help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
808
+ )
809
+ parser.add_argument(
810
+ "--test-branch",
811
+ type=str,
812
+ default=None,
813
+ help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
814
+ )
815
+
816
+ args = parser.parse_args()
817
+ if not args.local_dir:
818
+ args.local_dir = Path("/tmp/lerobot_dataset_v2")
819
+
820
+ robot_config = make_robot_config(args.robot) if args.robot is not None else None
821
+
822
+
823
+ del args.robot
824
+
825
+ convert_dataset(**vars(args), robot_config=robot_config)
826
+
827
+
828
+ if __name__ == "__main__":
829
+ main()
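As a companion to scenario 2 in the module docstring above, the sketch below shows one way to produce the episode-to-task JSON expected by '--tasks-path'. The helper and the instruction list are illustrative, not part of the package; only the output format (string episode_index keys mapping to one plain-English instruction each) is taken from the docstring and from how convert_dataset() reads the file.

```python
import json
from pathlib import Path

# Illustrative instructions, one per episode and in episode order.
# Replace with your own per-episode descriptions.
episode_instructions = [
    "Pick the Lego block and drop it in the box on the right.",
    "Insert the peg into the socket.",
    "Pick the Lego block and drop it in the box on the right.",
]

# '--tasks-path' expects a JSON object whose keys are episode indices
# (serialized as strings) and whose values are the instructions.
tasks_by_episode = {str(ep_idx): task for ep_idx, task in enumerate(episode_instructions)}
Path("tasks.json").write_text(json.dumps(tasks_by_episode, indent=4))
```

The resulting file is then passed as `--tasks-path tasks.json` together with `--repo-id` and `--local-dir`, mirroring the other CLI invocations shown in the docstring.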