opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/datasets/lerobot_dataset.py
@@ -0,0 +1,1910 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LeRobot dataset implementation for robot learning data management.

This module provides the core dataset implementation for loading, creating, and
managing robot learning datasets. It supports both loading existing datasets
from the HuggingFace Hub or local disk and creating new datasets for data
recording.

The dataset structure consists of:

- Metadata: Info, statistics, tasks, and episode information stored as JSON
- Data files: Episode data stored as Parquet files organized by chunks
- Videos: Optional video files for camera observations stored as MP4 files

Key Features:

- Temporal alignment: Supports delta timestamps for temporal feature
  alignment, enabling sampling of features at different time offsets with
  optional Gaussian noise for data augmentation.
- Multi-modal support: Handles images, videos, state vectors, actions, and
  text prompts with automatic format conversion and standardization.
- Version compatibility: Automatic version checking and backward compatibility
  handling for datasets created with older format versions.
- Asynchronous image writing: Optional async image writer for high-frequency
  data recording without blocking the main process.
- Statistics management: Per-episode and aggregated statistics for data
  normalization, with automatic computation and aggregation.
- Video handling: Supports multiple video backends (torchcodec, pyav,
  video_reader) for efficient video encoding and decoding.

Classes:

    DatasetMetadata
        Base class for dataset metadata management.

    LeRobotDatasetMetadata
        Metadata manager for LeRobot datasets with Hub integration, version
        checking, and statistics loading.

    GroundingDatasetMetadata
        Metadata manager for grounding datasets.

    BaseDataset
        Base PyTorch Dataset class with common functionality.

    LeRobotDataset
        Main dataset class for robot learning data, supporting loading from
        Hub/local disk, temporal alignment, video/image handling, and data
        recording.

Functions:
    retry_random_on_failure
        Decorator to retry dataset item retrieval with random indices on failure.

Example:
    Load an existing dataset:
    >>> dataset = LeRobotDataset(cfg, repo_id="my-robot-dataset")
    >>> dataloader = DataLoader(dataset, batch_size=32)

    Create a new dataset for recording:
    >>> dataset = LeRobotDataset.create(
    ...     repo_id="my-new-dataset",
    ...     fps=30,
    ...     features={"state": {"shape": (7,), "dtype": "float32"}},
    ...     use_videos=True
    ... )
"""

import contextlib
import functools
import logging
import math
import shutil
import traceback
from abc import abstractmethod
from pathlib import Path
from typing import Callable

import datasets
import numpy as np
import packaging.version
import PIL.Image
import torch
import torch.nn.functional as F  # noqa: N812
import torch.utils
from datasets import concatenate_datasets, load_dataset
from einops import rearrange
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError

from opentau.configs.train import TrainPipelineConfig
from opentau.constants import HF_OPENTAU_HOME
from opentau.datasets.compute_stats import aggregate_stats, compute_episode_stats
from opentau.datasets.image_writer import AsyncImageWriter, write_image
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING, LOSS_TYPE_MAPPING
from opentau.datasets.utils import (
    DEFAULT_FEATURES,
    DEFAULT_IMAGE_PATH,
    INFO_PATH,
    TASKS_PATH,
    append_jsonlines,
    backward_compatible_episodes_stats,
    check_timestamps_sync,
    check_version_compatibility,
    create_empty_dataset_info,
    create_lerobot_dataset_card,
    embed_images,
    get_delta_indices_soft,
    get_episode_data_index,
    get_hf_features_from_features,
    get_safe_version,
    hf_transform_to_torch,
    is_valid_version,
    load_advantages,
    load_episodes,
    load_episodes_stats,
    load_info,
    load_stats,
    load_tasks,
    validate_episode_buffer,
    validate_frame,
    write_episode,
    write_episode_stats,
    write_info,
    write_json,
)
from opentau.datasets.video_utils import (
    decode_video_frames,
    encode_video_frames,
    get_safe_default_codec,
    get_video_info,
)
from opentau.policies.value.configuration_value import ValueConfig
from opentau.policies.value.reward import (
    calculate_return_bins_with_equal_width,
)
from opentau.utils.utils import on_accelerate_main_proc


def retry_random_on_failure(f):
    """Decorator to retry dataset item retrieval with random indices on failure.

    When a dataset item fails to load, this decorator will retry with random
    indices up to `_total_rand_attempts` times before raising an error.

    Args:
        f: The `__getitem__` method to wrap.

    Returns:
        Wrapped function that retries on failure.
    """

    @functools.wraps(f)
    def wrapped(self, idx):
        g = getattr(self, "_rr_rng", None)
        total_attempts = getattr(self, "_total_rand_attempts", 0)
        if g is None:
            g = torch.Generator()
            g.manual_seed(torch.initial_seed())  # different seed per DataLoader worker
            self._rr_rng = g

        n = len(self)
        cur = idx
        exceptions = []
        indices_tried = []
        for _ in range(total_attempts + 1):
            try:
                indices_tried.append(cur)
                return f(self, cur)
            except Exception as e:
                print(f"Encountered failure to load data at index {cur}; retrying with a different index.")
                cur = int(torch.randint(0, n, (1,), generator=g))
                exceptions.append(e)

        tb_strings = [
            f"Attempt {i}: trying to fetch index {item} ...\n"
            + "".join(traceback.format_exception(type(e), e, e.__traceback__))
            for i, (e, item) in enumerate(zip(exceptions, indices_tried, strict=False))
        ]
        tb_blob = "\n".join(tb_strings)
        raise RuntimeError(
            f"Failed to load data after {total_attempts + 1} attempt(s). "
            "Check the following traceback for each attempt made.\n\n"
            f"{tb_blob}"
        )

    return wrapped
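
# A minimal usage sketch (hypothetical subclass, not part of this module):
# `wrapped` reads `_total_rand_attempts` via getattr, so a dataset opts in by
# defining that attribute and decorating its `__getitem__`:
#
#     class MyDataset(torch.utils.data.Dataset):
#         _total_rand_attempts = 3  # retry up to 3 random indices on failure
#
#         @retry_random_on_failure
#         def __getitem__(self, idx):
#             return self._load_item(idx)  # may raise on corrupt samples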

CODEBASE_VERSION = "v2.1"


class DatasetMetadata:
    """Base class for dataset metadata containing info and statistics.

    Attributes:
        info: Dictionary containing dataset information (features, fps, etc.).
        stats: Dictionary containing dataset statistics for normalization.
        repo_id: Repository ID of the dataset (set by subclasses).
    """

    def __init__(self, *, info: dict = None, stats: dict = None):
        self.info = info or {"features": {}}
        self.stats = stats or {}

        for feature_name in self.stats:
            for metric in self.stats[feature_name]:
                if isinstance(self.stats[feature_name][metric], (list, tuple)):
                    self.stats[feature_name][metric] = np.array(self.stats[feature_name][metric])
                # TODO: check stats[feature_name][metric].shape is broadcastable with features[feature_name]["shape"]

        self.repo_id = None

    @property
    def features(self) -> dict[str, dict]:
        """All features contained in the dataset."""
        return self.info["features"]

    @property
    def image_keys(self) -> list[str]:
        """Keys to access visual modalities stored as images."""
        return [key for key, ft in self.features.items() if ft["dtype"] == "image"]

    @property
    def video_keys(self) -> list[str]:
        """Keys to access visual modalities stored as videos."""
        return [key for key, ft in self.features.items() if ft["dtype"] == "video"]

    @property
    def camera_keys(self) -> list[str]:
        """Keys to access visual modalities (regardless of their storage method)."""
        return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]

    @property
    def names(self) -> dict[str, list | dict]:
        """Names of the various dimensions of vector modalities."""
        return {key: ft["names"] for key, ft in self.features.items()}

    @property
    def shapes(self) -> dict:
        """Shapes for the different features."""
        return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
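
# Example (illustrative values): list/tuple statistics are converted to numpy
# arrays at construction time, and the convenience properties slice
# `info["features"]`:
#
#     meta = DatasetMetadata(
#         info={"features": {"camera0": {"dtype": "video", "shape": (3, 224, 224)}}},
#         stats={"camera0": {"mean": [0.5, 0.5, 0.5]}},
#     )
#     meta.video_keys                      # -> ["camera0"]
#     meta.shapes                          # -> {"camera0": (3, 224, 224)}
#     type(meta.stats["camera0"]["mean"])  # -> numpy.ndarray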


class GroundingDatasetMetadata(DatasetMetadata):
    """Metadata class for grounding datasets (vision-language datasets)."""

    pass


class LeRobotDatasetMetadata(DatasetMetadata):
    """Metadata manager for LeRobot datasets with Hub integration and version handling.

    This class manages all metadata for LeRobot datasets, including dataset info,
    statistics, episodes, tasks, and advantages. It handles loading from local disk
    or HuggingFace Hub, version compatibility checking, and provides utilities for
    accessing dataset files and information.

    The class automatically handles:
    - Loading metadata from local disk or downloading from HuggingFace Hub
    - Version compatibility checking and automatic version resolution
    - Backward compatibility with older dataset formats (v2.0 vs v2.1)
    - Episode and task management
    - Statistics aggregation (per-episode and global)

    Attributes:
        repo_id: Repository ID of the dataset on HuggingFace Hub.
        root: Local root directory where the dataset is stored.
        revision: Git revision (branch/tag/commit) of the dataset.
        info: Dictionary containing dataset information (features, fps, paths, etc.).
        stats: Aggregated statistics dictionary (mean, std, min, max, count).
        episodes_stats: Per-episode statistics dictionary.
        episodes: Dictionary mapping episode_index to episode information.
        tasks: Dictionary mapping task_index to task descriptions.
        task_to_task_index: Reverse mapping from task description to task_index.
        advantages: Dictionary mapping (episode_index, timestamp) to advantage values.

    Example:
        Load metadata from Hub:
        >>> meta = LeRobotDatasetMetadata("lerobot/aloha_mobile_cabinet")
        >>> print(f"Total episodes: {meta.total_episodes}")

        Create new dataset metadata:
        >>> meta = LeRobotDatasetMetadata.create(
        ...     repo_id="my-dataset",
        ...     fps=30,
        ...     features={"state": {"dtype": "float32", "shape": (7,)}}
        ... )
    """

    def __init__(
        self,
        repo_id: str,
        root: str | Path | None = None,
        revision: str | None = None,
        force_cache_sync: bool = False,
    ):
        super().__init__()
        self.repo_id = repo_id
        self.revision = revision if revision else CODEBASE_VERSION
        self.root = Path(root) if root is not None else HF_OPENTAU_HOME / repo_id

        try:
            if force_cache_sync:
                raise FileNotFoundError
            self.load_metadata()
        except (FileNotFoundError, NotADirectoryError):
            if is_valid_version(self.revision):
                self.revision = get_safe_version(self.repo_id, self.revision)

            (self.root / "meta").mkdir(exist_ok=True, parents=True)
            self.pull_from_repo(allow_patterns="meta/")
            self.load_metadata()

    def load_metadata(self) -> None:
        """Load dataset metadata from disk.

        Loads info, tasks, episodes, statistics, and advantages from the
        dataset root directory. Handles version compatibility checks.
        """
        self.info = load_info(self.root)
        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
        self.tasks, self.task_to_task_index = load_tasks(self.root)
        self.episodes = load_episodes(self.root)
        if self._version < packaging.version.parse("v2.1"):
            self.stats = load_stats(self.root)
            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
        else:
            self.episodes_stats = load_episodes_stats(self.root)
            self.stats = aggregate_stats(list(self.episodes_stats.values()))

        self.advantages = load_advantages(self.root)

    def pull_from_repo(
        self,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
    ) -> None:
        snapshot_download(
            self.repo_id,
            repo_type="dataset",
            revision=self.revision,
            local_dir=self.root,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )

    @property
    def _version(self) -> packaging.version.Version:
        """Codebase version used to create this dataset."""
        return packaging.version.parse(self.info["codebase_version"])

    def get_data_file_path(self, ep_index: int) -> Path:
        """Get the file path for a specific episode's parquet data file.

        Args:
            ep_index: Episode index.

        Returns:
            Path to the parquet file for the episode.
        """
        ep_chunk = self.get_episode_chunk(ep_index)
        fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
        return Path(fpath)

    def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
        """Get the file path for a specific episode's video file.

        Args:
            ep_index: Episode index.
            vid_key: Video key/name (e.g., "camera0").

        Returns:
            Path to the video file for the episode.
        """
        ep_chunk = self.get_episode_chunk(ep_index)
        fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
        return Path(fpath)

    def get_episode_chunk(self, ep_index: int) -> int:
        """Get the chunk index for a given episode index.

        Episodes are grouped into chunks for efficient storage.

        Args:
            ep_index: Episode index.

        Returns:
            Chunk index containing this episode.
        """
        return ep_index // self.chunks_size

    @property
    def data_path(self) -> str:
        """Formattable string for the parquet files."""
        return self.info["data_path"]

    @property
    def video_path(self) -> str | None:
        """Formattable string for the video files."""
        return self.info["video_path"]

    @property
    def robot_type(self) -> str | None:
        """Robot type used in recording this dataset."""
        return self.info["robot_type"]

    @property
    def fps(self) -> int:
        """Frames per second used during data collection."""
        return self.info["fps"]

    @property
    def total_episodes(self) -> int:
        """Total number of episodes available."""
        return self.info["total_episodes"]

    @property
    def total_frames(self) -> int:
        """Total number of frames saved in this dataset."""
        return self.info["total_frames"]

    @property
    def total_tasks(self) -> int:
        """Total number of different tasks performed in this dataset."""
        return self.info["total_tasks"]

    @property
    def total_chunks(self) -> int:
        """Total number of chunks (groups of episodes)."""
        return self.info["total_chunks"]

    @property
    def chunks_size(self) -> int:
        """Max number of episodes per chunk."""
        return self.info["chunks_size"]
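
    # Worked example (the template shown is illustrative; the real one lives in
    # info.json): with chunks_size = 1000 and
    # data_path = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet",
    # episode 1002 resolves to:
    #
    #     meta.get_episode_chunk(1002)   # -> 1  (1002 // 1000)
    #     meta.get_data_file_path(1002)  # -> Path("data/chunk-001/episode_001002.parquet")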

    def get_task_index(self, task: str) -> int | None:
        """
        Given a task in natural language, returns its task_index if the task already exists in the dataset,
        otherwise returns None.
        """
        return self.task_to_task_index.get(task, None)

    def add_task(self, task: str):
        """
        Given a task in natural language, add it to the dictionary of tasks.
        """
        if task in self.task_to_task_index:
            raise ValueError(f"The task '{task}' already exists and can't be added twice.")

        task_index = self.info["total_tasks"]
        self.task_to_task_index[task] = task_index
        self.tasks[task_index] = task
        self.info["total_tasks"] += 1

        task_dict = {
            "task_index": task_index,
            "task": task,
        }
        append_jsonlines(task_dict, self.root / TASKS_PATH)
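
    # Example (illustrative task string): registering a new task appends one
    # JSONL record to meta/tasks.jsonl and keeps both in-memory mappings in sync:
    #
    #     idx = meta.info["total_tasks"]          # next free index
    #     meta.add_task("fold the towel")
    #     meta.get_task_index("fold the towel")   # -> idx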

    def save_episode(
        self,
        episode_index: int,
        episode_length: int,
        episode_tasks: list[str],
        episode_stats: dict[str, dict],
    ) -> None:
        self.info["total_episodes"] += 1
        self.info["total_frames"] += episode_length

        chunk = self.get_episode_chunk(episode_index)
        if chunk >= self.total_chunks:
            self.info["total_chunks"] += 1

        self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
        self.info["total_videos"] += len(self.video_keys)
        if len(self.video_keys) > 0:
            self.update_video_info()

        write_info(self.info, self.root)

        episode_dict = {
            "episode_index": episode_index,
            "tasks": episode_tasks,
            "length": episode_length,
        }
        self.episodes[episode_index] = episode_dict
        write_episode(episode_dict, self.root)

        self.episodes_stats[episode_index] = episode_stats
        self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
        write_episode_stats(episode_index, episode_stats, self.root)

    def update_video_info(self) -> None:
        """
        Warning: this function writes info from the first episode's videos, implicitly assuming that all
        videos have been encoded the same way. It also assumes that the first episode exists.
        """
        for key in self.video_keys:
            if not self.features[key].get("info", None):
                video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
                self.info["features"][key]["info"] = get_video_info(video_path)

    def __repr__(self):
        feature_keys = list(self.features)
        return (
            f"{self.__class__.__name__}({{\n"
            f"    Repository ID: '{self.repo_id}',\n"
            f"    Total episodes: '{self.total_episodes}',\n"
            f"    Total frames: '{self.total_frames}',\n"
            f"    Features: '{feature_keys}',\n"
            "})"
        )

    @classmethod
    def create(
        cls,
        repo_id: str,
        fps: int,
        root: str | Path | None = None,
        robot_type: str | None = None,
        features: dict | None = None,
        use_videos: bool = True,
    ) -> "LeRobotDatasetMetadata":
        """Creates metadata for a LeRobotDataset."""
        obj = cls.__new__(cls)
        obj.repo_id = repo_id
        obj.root = Path(root) if root is not None else HF_OPENTAU_HOME / repo_id

        obj.root.mkdir(parents=True, exist_ok=False)

        if features is None:
            raise ValueError("Dataset features must be explicitly passed upon creation.")
        else:
            # TODO(aliberts, rcadene): implement sanity check for features
            features = {**features, **DEFAULT_FEATURES}

            # check that none of the feature names contains a "/",
            # as this would break the dict flattening in the stats computation, which uses '/' as separator
            for key in features:
                if "/" in key:
                    raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")

        obj.tasks, obj.task_to_task_index = {}, {}
        obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
        if len(obj.video_keys) > 0 and not use_videos:
            raise ValueError("Video features were provided in `features` while `use_videos` is False.")
        write_json(obj.info, obj.root / INFO_PATH)
        obj.revision = None
        return obj


class BaseDataset(torch.utils.data.Dataset):
    """Base class for all robot learning datasets.

    This abstract base class provides common functionality for both LeRobotDataset
    and GroundingDataset, including data format standardization, image processing,
    and vector padding. It ensures all datasets conform to a standard format
    regardless of their source or structure.

    Key Features:
    - Standard data format conversion: Maps dataset-specific feature names
      to standard names (camera0, camera1, state, actions, etc.)
    - Image standardization: Resizes and pads images to target resolution
      while maintaining aspect ratio
    - Vector padding: Pads state and action vectors to maximum dimensions
    - Data type conversion: Converts floating-point tensors to bfloat16 for
      memory efficiency
    - String normalization: Ensures prompts and responses have consistent
      newline formatting

    Subclasses must implement:
    - `_get_feature_mapping_key()`: Returns the key used for feature name
      mapping (e.g., "lerobot/aloha_mobile_cabinet")
    - `_separate_image_in_time()`: Separates temporal image sequences into
      individual frames

    Attributes:
        resolution: Target image resolution (height, width).
        num_cams: Number of camera views in each sample.
        max_state_dim: Maximum dimension for state vectors.
        max_action_dim: Maximum dimension for action vectors.
        action_chunk: Number of actions processed in a chunk.

    Example:
        Create a custom dataset:
        >>> class MyDataset(BaseDataset):
        ...     def _get_feature_mapping_key(self):
        ...         return "my-dataset"
        ...     def _separate_image_in_time(self, item):
        ...         pass  # No temporal separation needed
    """

    def __init__(self, cfg: TrainPipelineConfig):
        super().__init__()
        # Standard Data Format parameters
        self.resolution = cfg.resolution  # resolution of images (H, W) in data sample
        self.num_cams = cfg.num_cams  # number of cameras in each data sample
        self.max_state_dim = cfg.max_state_dim  # maximum dimension of the state vector
        self.max_action_dim = cfg.max_action_dim  # maximum dimension of the action vector
        self.action_chunk = cfg.action_chunk  # number of actions to be processed in a chunk

    @abstractmethod
    def _get_feature_mapping_key(self) -> str:
        r"""Returns the key used for feature mapping."""
        pass

    @abstractmethod
    def _separate_image_in_time(self, item: dict):
        r"""Some keys correspond to 2 images, where the first is the image at the current timestamp and the
        second is the image from some time ago. We separate these 2 images into different keys by modifying
        the `item` dictionary. For example, {"image_key": torch.zeros(2, 3, 224, 224),
        "image_key_is_pad": [False, True]} will become
        {
            "image_key": torch.zeros(3, 224, 224),
            "image_key_is_pad": False,
        }.
        """
        raise NotImplementedError

    def _standardize_images(self, item, standard_item, n_cams, is_local) -> list[bool]:
        """Standardize image features to a common format.

        Resizes images to the target resolution with padding, and tracks
        which images are padded.

        Args:
            item: Input item dictionary with original image keys.
            standard_item: Output dictionary to populate with standardized images.
            n_cams: Number of cameras to process.
            is_local: Whether processing local (past) images.

        Returns:
            List of boolean values indicating which images are padded.
        """
        name_map = DATA_FEATURES_NAME_MAPPING[self._get_feature_mapping_key()]
        image_is_pad = []
        for cam_idx in range(n_cams):
            std_key = f"camera{cam_idx}"
            key = name_map.get(std_key)

            if key is None:
                standard_item[std_key] = torch.zeros((3, *self.resolution))
                image_is_pad.append(True)
            else:
                standard_item[std_key] = self.resize_with_pad(
                    item[key],
                    self.resolution[1],
                    self.resolution[0],
                    pad_value=0,
                )
                image_is_pad.append(item.get(key + "_is_pad", torch.tensor(False)).item())
            assert (
                len(standard_item[std_key].shape) == 3
                and standard_item[std_key].shape[0] == 3
                and standard_item[std_key].min() >= 0.0 - 1e-6  # bfloat16 results in precision loss
                and standard_item[std_key].max() <= 1.0 + 1e-6  # bfloat16 results in precision loss
            ), (
                f"Expected image {std_key} to have shape (3, H, W) with values in [0, 1], "
                f"Got shape {standard_item[std_key].shape}, "
                f"min={standard_item[std_key].min()}, "
                f"max={standard_item[std_key].max()}, "
                f"self={self._get_feature_mapping_key()}."
            )

        return image_is_pad

    def _to_standard_data_format(self, item: dict) -> dict:
        """Convert dataset item to standard data format.

        Standardizes feature names, separates images in time, pads vectors,
        and ensures consistent data types and formats.

        Args:
            item: Raw dataset item dictionary.

        Returns:
            Dictionary with standardized feature names and formats.
        """
        name_map = DATA_FEATURES_NAME_MAPPING[self._get_feature_mapping_key()]
        self._separate_image_in_time(item)

        standard_item = {}
        img_is_pad = self._standardize_images(item, standard_item, self.num_cams, False)

        for new_key, key in name_map.items():
            if new_key.startswith("camera"):
                continue
            standard_item[new_key] = item[key]

        # pad state and action vectors
        standard_item["state"] = self.pad_vector(standard_item["state"], self.max_state_dim)
        standard_item["actions"] = self.pad_vector(standard_item["actions"], self.max_action_dim)

        standard_item["img_is_pad"] = torch.tensor(img_is_pad, dtype=torch.bool)
        standard_item["action_is_pad"] = item[name_map["actions"] + "_is_pad"]

        # add loss type
        standard_item["loss_type"] = LOSS_TYPE_MAPPING[self._get_feature_mapping_key()]

        # cast all tensors in standard_item to bfloat16
        for key, value in standard_item.items():
            if isinstance(value, torch.Tensor) and value.dtype.is_floating_point:
                standard_item[key] = value.to(dtype=torch.bfloat16)

        # ensure that non-empty strings contain exactly one newline character at the end of the string
        for key in ["prompt", "response"]:
            if standard_item[key].endswith(
                "\n"
            ):  # ensure there isn't going to be an extra space at the end after calling replace
                standard_item[key] = standard_item[key][:-1]
            standard_item[key] = standard_item[key].replace("\n", " ") + "\n"

        return standard_item
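
    # Worked example of the newline normalization above:
    #     "open the\ndrawer\n"  ->  "open the drawer\n"
    # Interior newlines become spaces, and exactly one trailing newline remains.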

    def resize_with_pad(self, img, width, height, pad_value=0) -> torch.Tensor:
        """Resize an image to target dimensions with padding.

        Maintains aspect ratio by resizing to fit within target dimensions,
        then pads on the left and top to reach the exact target size.

        Args:
            img: Input image tensor of shape (C, H, W).
            width: Target width.
            height: Target height.
            pad_value: Value to use for padding. Defaults to 0.

        Returns:
            Resized and padded image tensor of shape (C, height, width).

        Raises:
            ValueError: If the input image doesn't have 4 dimensions when reshaped.
        """
        # add a batch dimension; the resize below is a near no-op when the image already fits
        img = rearrange(img, "c h w -> 1 c h w")
        if img.ndim != 4:
            raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

        cur_height, cur_width = img.shape[2:]

        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_img = F.interpolate(
            img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
        )

        pad_height = max(0, int(height - resized_height))
        pad_width = max(0, int(width - resized_width))

        # pad on left and top of image
        padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)

        # rearrange back to (c, h, w)
        padded_img = rearrange(padded_img, "1 c h w -> c h w")
        return padded_img
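
    # Worked example (illustrative numbers): resizing a (3, 480, 640) image to
    # width=224, height=224 gives ratio = max(640/224, 480/224) ≈ 2.857, so the
    # image is resized to (3, 168, 224) and padded with 56 rows on top, yielding
    # a (3, 224, 224) tensor with the content aligned to the bottom.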

    @staticmethod
    def pad_vector(vector, new_dim):
        """Only the last dimension of the vector is padded to 'new_dim' with zeros."""
        if vector.shape[-1] == new_dim:
            return vector
        shape = list(vector.shape)
        current_dim = shape[-1]
        shape[-1] = new_dim
        new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
        new_vector[..., :current_dim] = vector
        return new_vector
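
    # Example: pad_vector(torch.ones(2, 7), 32) returns a (2, 32) tensor whose
    # first 7 entries along the last dimension are ones and the rest zeros.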


class LeRobotDataset(BaseDataset):
    """Main dataset class for loading and managing robot learning data.

    This class provides a PyTorch Dataset interface for robot learning datasets
    stored in the LeRobot format. It supports loading from HuggingFace Hub or
    local disk, handles temporal alignment with delta timestamps, manages video
    and image data, and provides data recording capabilities.

    The dataset structure consists of:
    - Metadata: JSON files containing info, statistics, episodes, tasks
    - Data files: Parquet files organized by chunks containing episode data
    - Videos: Optional MP4 files for camera observations

    Key Features:
    - Hub integration: Automatic download from HuggingFace Hub with version
      compatibility checking
    - Temporal alignment: Delta timestamps enable sampling features at
      different time offsets with optional Gaussian noise for augmentation
    - Video/image handling: Supports both video files and individual images
      with automatic frame extraction and synchronization
    - Episode filtering: Load specific episodes by index
    - Data recording: Create new datasets and add episodes programmatically
    - Statistics: Per-episode and aggregated statistics for normalization

    Two Usage Modes:
    1. Loading existing datasets: From local disk or HuggingFace Hub
    2. Creating new datasets: Using the `create()` classmethod for data
       recording

    Attributes:
        cfg: Training pipeline configuration.
        repo_id: Repository ID of the dataset.
        root: Local root directory for the dataset.
        meta: LeRobotDatasetMetadata instance containing all metadata.
        hf_dataset: HuggingFace Dataset containing parquet data.
        episodes: Dictionary mapping episode_index to episode info.
        image_transforms: Optional image transforms to apply.
        delta_timestamps_params: Processed delta timestamp parameters.
        feature2group: Mapping from features to temporal groups.
        video_backend: Backend used for video decoding.
        standardize: Whether to standardize data format.

    Example:
        Load dataset from Hub:
        >>> dataset = LeRobotDataset(cfg, repo_id="lerobot/aloha")
        >>> dataloader = DataLoader(dataset, batch_size=32)

        Load specific episodes:
        >>> dataset = LeRobotDataset(
        ...     cfg,
        ...     repo_id="lerobot/aloha",
        ...     episodes=[0, 1, 2, 5, 10]
        ... )

        Create new dataset for recording:
        >>> dataset = LeRobotDataset.create(
        ...     cfg,
        ...     repo_id="my-new-dataset",
        ...     fps=30,
        ...     features={"state": {"dtype": "float32", "shape": (7,)}}
        ... )
    """

    def __init__(
        self,
        cfg: TrainPipelineConfig,
        repo_id: str,
        root: str | Path | None = None,
        episodes: list[int] | None = None,
        image_transforms: Callable | None = None,
        delta_timestamps: dict[str, np.ndarray | list[float]] | None = None,
        delta_timestamps_std: dict[str, np.ndarray | list[float]] | None = None,
        delta_timestamps_lower: dict[str, np.ndarray | list[float]] | None = None,
        delta_timestamps_upper: dict[str, np.ndarray | list[float]] | None = None,
        feature2group: dict[str, tuple[str, (list[int] | int | None)]] | None = None,
        tolerance_s: float = 1e-4,
        revision: str | None = None,
        force_cache_sync: bool = False,
        download_videos: bool = True,
        video_backend: str | None = None,
        image_resample_strategy: str = "nearest",
        vector_resample_strategy: str = "nearest",
        standardize: bool = True,
        return_advantage_input: bool = False,
    ):
        """Initialize LeRobotDataset.

        2 modes are available for instantiating this class, depending on 2 different use cases:

        1. Your dataset already exists:

            - On your local disk in the 'root' folder. This is typically the case when you recorded your
              dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
              with 'root' will load your dataset directly from disk. This can happen while you're offline (no
              internet connection).

            - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
              your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
              the dataset from that address and load it, provided your dataset is compliant with
              codebase_version v2.0. If your dataset was created before this new format, you will be
              prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at
              lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.

        2. Your dataset doesn't already exist (either on local disk or on the Hub): you can create an empty
           LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or to port
           an existing dataset to the LeRobotDataset format.

        In terms of files, LeRobotDataset encapsulates 3 main things:

        - metadata:

            - info contains various information about the dataset like shapes, keys, fps etc.
            - stats stores the dataset statistics of the different modalities for normalization
            - tasks contains the prompts for each task of the dataset, which can be used for
              task-conditioned training.

        - hf_dataset (from datasets.Dataset), which will read any values from parquet files.

        - videos (optional) from which frames are loaded to be synchronous with data from parquet files.

        A typical LeRobotDataset looks like this from its root path::

            .
            ├── data
            │   ├── chunk-000
            │   │   ├── episode_000000.parquet
            │   │   ├── episode_000001.parquet
            │   │   ├── episode_000002.parquet
            │   │   └── ...
            │   ├── chunk-001
            │   │   ├── episode_001000.parquet
            │   │   ├── episode_001001.parquet
            │   │   ├── episode_001002.parquet
            │   │   └── ...
            │   └── ...
            ├── meta
            │   ├── episodes.jsonl
            │   ├── info.json
            │   ├── stats.json
            │   └── tasks.jsonl
            └── videos
                ├── chunk-000
                │   ├── observation.images.laptop
                │   │   ├── episode_000000.mp4
                │   │   ├── episode_000001.mp4
                │   │   ├── episode_000002.mp4
                │   │   └── ...
                │   ├── observation.images.phone
                │   │   ├── episode_000000.mp4
                │   │   ├── episode_000001.mp4
                │   │   ├── episode_000002.mp4
                │   │   └── ...
                ├── chunk-001
                └── ...

        Note that this file-based structure is designed to be as versatile as possible. The files are split
        by episode, which allows more granular control over which episodes one wants to use and download.
        The structure of the dataset is entirely described in the info.json file, which can be easily
        downloaded or viewed directly on the hub before downloading any actual data. The file types used are
        very simple and do not need complex tools to be read: only .parquet, .json and .mp4 files are used
        (and .md for the README).

        Args:
            cfg (TrainPipelineConfig): Training configuration object.
            repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
                will be stored under root/repo_id.
            root (Path | None, optional): Local directory to use for downloading/writing files. You can also
                set the HF_OPENTAU_HOME environment variable to point to a different location. Defaults to
                '~/.cache/huggingface/opentau'.
            episodes (list[int] | None, optional): If specified, this will only load episodes specified by
                their episode_index in this list. Defaults to None.
            image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
                torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
                from videos or images). Defaults to None.
            delta_timestamps (dict[str, list[float]] | None, optional): Dictionary where each key is a group
                name and its corresponding value is a list of delta timestamps in seconds. For example,
                {'group1': [0, 0.1]} means features of group1 will be returned as a chunk of 2, with the
                first element being the value at the current time and the second element being the value at
                current time + 0.1 seconds. This will also add a key named '{feature}_is_pad' to the returned
                item, with a boolean type and a length of 2, indicating whether the feature is padded or not.
                Padding happens when t + 0.1 is outside the episode time range. Defaults to None.
            delta_timestamps_std (dict[str, list[float]] | None, optional): Similar to delta_timestamps, but
                specifies an optional standard deviation for the delta timestamps. If a key is absent, the
                delta timestamps for that key will be deterministic. If a key is present without
                corresponding delta_timestamps, it will be ignored. E.g., delta_timestamps={'group1': [0, 0.1]}
                and delta_timestamps_std={'group1': [0, 0.05]} will result in a chunk of 2, with the first
                element being the feature at the current time and the second element at a time following the
                Gaussian distribution N(t+0.1, 0.05^2). When it takes on a value outside the episode, the
                corresponding element in `{feature}_is_pad` will be set to True. Defaults to None.
            delta_timestamps_lower (dict[str, list[float]] | None, optional): Similar to delta_timestamps_std,
                but specifies a minimum value for the delta timestamps. When specified, the delta timestamps
                will be lower-clipped accordingly. Defaults to None.
            delta_timestamps_upper (dict[str, list[float]] | None, optional): Similar to delta_timestamps_std,
                but specifies a maximum value for the delta timestamps. When specified, the delta timestamps
                will be upper-clipped accordingly. Defaults to None.
            feature2group (dict[str, tuple[str, (list[int] | int | None)]] | None, optional): Dictionary
                mapping every individual feature to a tuple of (group name, indices). Group names are keys
                passed to delta_timestamps. If `indices` is None, all indices in the group are used. If
                indices is a list, only those indices are used, in the corresponding order, including
                duplicates if present. If indices is an int, only that index is returned, resulting in a
                reduction in ndim by 1.
                For example, `feature2group={'action': ('group1', None), 'observation.state': ('group2', 0),
                'observation.images.left_hand': ('group2', [0, 1])}` means the feature `action` will use
                resolved `delta_timestamps` from `group1` and will return every index. Also,
                `observation.state` will pick the first element (index 0) of `group2`.
                `observation.images.left_hand` will pick the first and second elements (indices 0 and 1) of
                `group2` and return them as a chunk of 2 images. The first element of
                `observation.images.left_hand` and the state vector will always be sampled at the same
                timestamp despite having Gaussian noise applied, because they are in the same group.
            tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually
                in sync with the fps value. It is used at the init of the dataset to make sure that each
                timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
                decoded from video files. Defaults to 1e-4.
            revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
                commit hash. Defaults to the current codebase version tag.
            download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
                video files are already present on local disk, they won't be downloaded again. Defaults to
                True.
            video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to
                torchcodec when available on the platform; otherwise, defaults to 'pyav'. You can also use
                the 'pyav' decoder used by Torchvision, which used to be the default option, or
                'video_reader', which is another Torchvision decoder.
            image_resample_strategy (str): Resampling strategy to use for image features.
                If 'linear', it will use linear interpolation between the two immediate timestamps.
                If 'nearest', it will use nearest neighbor interpolation.
                Defaults to 'nearest'.
            vector_resample_strategy (str): Resampling strategy to use for non-image features, such as
                action or state.
                If 'linear', it will use linear interpolation between the two immediate timestamps.
                If 'nearest', it will use nearest neighbor interpolation.
                Defaults to 'nearest'.
            standardize (bool, optional): Flag to enable standardization in `__getitem__`. Defaults to True.
            return_advantage_input (bool, optional): Flag to return advantage inputs ("success",
                "episode_end_idx", "current_idx", "last_step", "episode_index", "timestamp"). Defaults to
                False. Ignored if standardize is False.
        """
        super().__init__(cfg)
        self.cfg = cfg
        self.repo_id = repo_id
        self.root = Path(root) if root else HF_OPENTAU_HOME / repo_id
        self.image_transforms = image_transforms
        if bool(delta_timestamps) ^ bool(feature2group):
            raise ValueError(
                "Either both delta_timestamps and feature2group should be provided, or neither of them."
            )
        # delta_timestamps_params is a 4-tuple (mean, std, lower, upper)
        self.delta_timestamps_params = self.compute_delta_params(
            delta_timestamps,
            delta_timestamps_std,
            delta_timestamps_lower,
            delta_timestamps_upper,
        )
        self.feature2group = feature2group or {}
        self._check_feature_group_mapping()
        self.episodes = episodes
        self.tolerance_s = tolerance_s
        self.revision = revision if revision else CODEBASE_VERSION
        self.video_backend = video_backend if video_backend else get_safe_default_codec()

        if image_resample_strategy not in ["linear", "nearest"]:
            raise ValueError(
                f"Invalid image resample strategy: {image_resample_strategy}. Choose 'linear' or 'nearest'."
            )
        if vector_resample_strategy not in ["linear", "nearest"]:
            raise ValueError(
                f"Invalid vector resample strategy: {vector_resample_strategy}. Choose 'linear' or 'nearest'."
            )
        self.image_resample_strategy = image_resample_strategy
        self.vector_resample_strategy = vector_resample_strategy

        self.standardize = standardize
        if return_advantage_input and not standardize:
            logging.warning(
                "`return_advantage_input` is True while `standardize` is False. "
                "No advantage inputs will be returned."
            )
        self.return_advantage_input = return_advantage_input

        # Unused attributes
        self.image_writer = None
        self.episode_buffer = None

        self.root.mkdir(exist_ok=True, parents=True)

        # Load metadata
        self.meta = LeRobotDatasetMetadata(
            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
        )
        if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
            episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
            self.stats = aggregate_stats(episodes_stats)

        if self.episodes is None:
            self.episodes = list(self.meta.episodes)

        # Load actual data
        try:
            if force_cache_sync:
                raise FileNotFoundError
            assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
            self.hf_dataset = self.load_hf_dataset()
        except (AssertionError, FileNotFoundError, NotADirectoryError):
            self.revision = get_safe_version(self.repo_id, self.revision)
            self.download_episodes(download_videos)
            self.hf_dataset = self.load_hf_dataset()

        self.episode_data_index, self.epi2idx = get_episode_data_index(self.meta.episodes, self.episodes)

        # Check timestamps
        # If a transform is set, with_transform will decode all columns of a row before returning the
        # desired column(s).
        no_transform_ds = self.hf_dataset.with_transform(None).with_format("numpy")
        logging.info("Checking timestamps synchronization...")
        timestamps = np.asarray(no_transform_ds["timestamp"], dtype=np.float32)
        episode_indices = np.asarray(no_transform_ds["episode_index"], dtype=np.int64)
        ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
        check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
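
    # Illustrative sketch (not part of the original file): `delta_timestamps` and
    # `feature2group` must be provided together, as enforced above. Group and
    # feature names below are hypothetical; a pairing could look like this, where
    # "action" is read at the current step and the next two steps, while
    # "observation.state" reads only the first offset:
    #
    #     delta_timestamps = {"action_group": [0.0, 1 / fps, 2 / fps]}
    #     feature2group = {
    #         "action": ("action_group", None),         # all 3 offsets
    #         "observation.state": ("action_group", 0), # first offset only
    #     }
    #
    # The optional std/lower/upper dicts (see `compute_delta_params`) would share
    # the same keys and sequence lengths as `delta_timestamps`.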

    @on_accelerate_main_proc(local=True, _sync=True)
    def push_to_hub(
        self,
        branch: str | None = None,
        tags: list | None = None,
        license: str | None = "apache-2.0",
        tag_version: bool = True,
        push_videos: bool = True,
        private: bool = False,
        allow_patterns: list[str] | str | None = None,
        upload_large_folder: bool = False,
        **card_kwargs,
    ) -> None:
        ignore_patterns = ["images/"]
        if not push_videos:
            ignore_patterns.append("videos/")

        hub_api = HfApi()
        hub_api.create_repo(
            repo_id=self.repo_id,
            private=private,
            repo_type="dataset",
            exist_ok=True,
        )
        if branch:
            hub_api.create_branch(
                repo_id=self.repo_id,
                branch=branch,
                revision=self.revision,
                repo_type="dataset",
                exist_ok=True,
            )

        upload_kwargs = {
            "repo_id": self.repo_id,
            "folder_path": self.root,
            "repo_type": "dataset",
            "revision": branch,
            "allow_patterns": allow_patterns,
            "ignore_patterns": ignore_patterns,
        }
        if upload_large_folder:
            hub_api.upload_large_folder(**upload_kwargs)
        else:
            hub_api.upload_folder(**upload_kwargs)

        if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
            card = create_lerobot_dataset_card(
                tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
            )
            card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)

        if tag_version:
            with contextlib.suppress(RevisionNotFoundError):
                hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
            hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
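
    # Usage sketch (illustrative; `ds` is a hypothetical dataset instance that has
    # already been materialized under `ds.root`):
    #
    #     ds.push_to_hub(tags=["robotics"], private=True)
    #
    # With `upload_large_folder=True` the upload goes through
    # `HfApi.upload_large_folder`, which is better suited to repositories with
    # many large video files and can resume interrupted uploads.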

    @on_accelerate_main_proc(local=True, _sync=True)
    def pull_from_repo(
        self,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
    ) -> None:
        snapshot_download(
            self.repo_id,
            repo_type="dataset",
            revision=self.revision,
            local_dir=self.root,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )

    def download_episodes(self, download_videos: bool = True) -> None:
        """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
        will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
        dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
        in 'local_dir', they won't be downloaded again.
        """
        # TODO(rcadene, aliberts): implement faster transfer
        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
        files = None
        ignore_patterns = None if download_videos else "videos/"
        if self.episodes is not None:
            files = self.get_episodes_file_paths()

        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

    def get_episodes_file_paths(self) -> list[Path]:
        """Get file paths for all selected episodes.

        Returns paths for both parquet data files and video files (if applicable)
        for all episodes in the dataset.

        Returns:
            List of file paths for episode data and videos.
        """
        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
        fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
        if len(self.meta.video_keys) > 0:
            video_files = [
                str(self.meta.get_video_file_path(ep_idx, vid_key))
                for vid_key in self.meta.video_keys
                for ep_idx in episodes
            ]
            fpaths += video_files

        return fpaths

    def load_hf_dataset(self) -> datasets.Dataset:
        """hf_dataset contains all the observations, states, actions, rewards, etc."""
        if self.episodes is None:
            path = str(self.root / "data")
            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
        else:
            files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
            hf_dataset = load_dataset("parquet", data_files=files, split="train")

        # TODO(aliberts): hf_dataset.set_format("torch")
        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    def create_hf_dataset(self) -> datasets.Dataset:
        """Create an empty HuggingFace dataset with the correct features.

        Returns:
            Empty dataset with features matching the dataset specification.
        """
        features = get_hf_features_from_features(self.features)
        ft_dict = {col: [] for col in features}
        hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")

        # TODO(aliberts): hf_dataset.set_format("torch")
        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    @property
    def fps(self) -> int:
        """Frames per second used during data collection."""
        return self.meta.fps

    @property
    def num_frames(self) -> int:
        """Number of frames in selected episodes."""
        return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames

    @property
    def num_episodes(self) -> int:
        """Number of episodes selected."""
        return len(self.episodes) if self.episodes is not None else self.meta.total_episodes

    @property
    def features(self) -> dict[str, dict]:
        return self.meta.features

    @property
    def hf_features(self) -> datasets.Features:
        """Features of the hf_dataset."""
        if self.hf_dataset is not None:
            return self.hf_dataset.features
        else:
            return get_hf_features_from_features(self.features)

    def _get_query_indices_soft(
        self, idx: int, ep_idx: int
    ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
        """Get soft (float) indices for querying features with delta timestamps.

        Computes indices for features based on delta timestamps, accounting for
        episode boundaries. Returns both query indices and padding masks.

        Args:
            idx: Current data index.
            ep_idx: Current episode index.

        Returns:
            Tuple of (query_indices, padding):
                - query_indices: Dictionary mapping feature names to soft indices.
                - padding: Dictionary mapping feature names to boolean padding masks.
        """
        ep_start = self.episode_data_index["from"][self.epi2idx[ep_idx]].item()
        ep_end = self.episode_data_index["to"][self.epi2idx[ep_idx]].item()

        # Get the delta_indices by group
        delta_indices = get_delta_indices_soft(self.delta_timestamps_params, self.fps)
        # Map from group to feature
        delta_indices = {
            feature: delta_indices[group][
                slice(None) if indices is None else [indices] if isinstance(indices, int) else indices
            ]
            for feature, (group, indices) in self.feature2group.items()
        }
        query_indices = {
            key: np.clip(idx + delta_idx, ep_start, ep_end - 1) for key, delta_idx in delta_indices.items()
        }
        padding = {  # Pad values outside of current episode range
            f"{key}_is_pad": torch.BoolTensor((idx + delta_idx < ep_start) | (idx + delta_idx >= ep_end))
            for key, delta_idx in delta_indices.items()
        }
        return query_indices, padding
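
    # Minimal sketch of the clip/padding math above (illustrative): for an episode
    # spanning rows [100, 200), current idx=101 and per-feature delta indices
    # [-5, 0, 5]:
    #
    #     import numpy as np
    #     delta_idx = np.array([-5.0, 0.0, 5.0])
    #     query = np.clip(101 + delta_idx, 100, 199)                  # [100., 101., 106.]
    #     is_pad = (101 + delta_idx < 100) | (101 + delta_idx >= 200) # [True, False, False]
    #
    # The first offset falls before the episode start, so its value is clipped to
    # row 100 and flagged as padding.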

    def _get_query_timestamps(
        self,
        current_ts: float,
        query_indices: dict[str, np.ndarray] | None = None,
    ) -> dict[str, np.ndarray]:
        """Get query timestamps for video features.

        Converts soft indices to timestamps for video frame extraction.
        If query_indices is provided, uses them; otherwise uses the current timestamp.

        Args:
            current_ts: Current timestamp in seconds.
            query_indices: Optional dictionary of soft indices for features.

        Returns:
            Dictionary mapping video keys to query timestamps.
        """
        if query_indices:
            # In case values are lists
            query_indices = {k: np.array(v, dtype=np.float32) for k, v in query_indices.items()}
            q_indices = next(iter(query_indices.values()))
            # Pick any (soft) row index, which is guaranteed to be within [ep_start, ep_end), then take the floor
            in_ep_row_idx = math.floor(q_indices[0])
            # Index of the episode (not index of row). E.g., episode_index = 36 for row index = 10000
            ep_idx = self.hf_dataset.select([in_ep_row_idx])["episode_index"][0].item()
            # Row index where the current episode starts
            ep_start_row_idx = self.episode_data_index["from"][self.epi2idx[ep_idx]].item()
        else:
            ep_start_row_idx = None

        query_timestamps = {}
        for key in self.meta.video_keys:
            if query_indices is not None and key in query_indices:
                query_timestamps[key] = (query_indices[key] - ep_start_row_idx) / self.fps
            else:
                query_timestamps[key] = np.array([current_ts], dtype=np.float32)

        return query_timestamps

    def _query_hf_dataset_soft(self, soft_indices: dict[str, np.ndarray]) -> dict:
        """Query dataset using soft (float) indices with interpolation.

        Converts soft indices to hard indices based on the resample strategy
        (linear interpolation or nearest neighbor).

        Args:
            soft_indices: Dictionary mapping feature names to soft (float) indices.

        Returns:
            Dictionary of feature values queried from the dataset.

        Raises:
            ValueError: If vector_resample_strategy is not 'linear' or 'nearest'.
        """
        # Soft indices are float indices that need to be converted to hard (integer) indices
        if self.vector_resample_strategy == "linear":
            floor_indices = {k: np.floor(v).astype(int) for k, v in soft_indices.items()}
            dist2floor = {k: v - floor_indices[k] for k, v in soft_indices.items()}
            # In the unlikely case that the soft index is exactly (ep_end - 1), floor will be (ep_end - 1),
            # and (floor + 1) will be ep_end, which may be out of bounds (despite usually being the start of
            # the next episode). Therefore, we add 0 instead of 1 whenever the distance to floor is 0.
            ceil_indices = {k: floor_indices[k] + (dist2floor[k] > 0.0) for k, v in soft_indices.items()}
            q_floor = self._query_hf_dataset(floor_indices)
            q_ceil = self._query_hf_dataset(ceil_indices)

            item = {}
            for k, d2f in dist2floor.items():
                if k not in q_floor:
                    continue
                d2f = torch.tensor(d2f)
                d2f = rearrange(d2f, f"n -> {'n' + ' 1' * (q_floor[k].ndim - 1)}")
                item[k] = (1.0 - d2f) * q_floor[k] + d2f * q_ceil[k]
            return item
        elif self.vector_resample_strategy == "nearest":
            hard_indices = {k: v.round().astype(int) for k, v in soft_indices.items()}
            return self._query_hf_dataset(hard_indices)

        raise ValueError(
            f"Unsupported vector_resample_strategy: {self.vector_resample_strategy}. Choose 'linear' or 'nearest'."
        )
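
    # Interpolation sketch (illustrative; row[i] denotes the stacked feature value
    # at hard index i): for a soft index of 12.25, the 'linear' strategy blends
    # rows 12 and 13 with weights 0.75 and 0.25:
    #
    #     floor, d2f = 12, 0.25
    #     value = (1.0 - d2f) * row[12] + d2f * row[13]
    #
    # while the 'nearest' strategy simply reads row round(12.25) == 12.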

    def _query_hf_dataset(self, hard_indices: dict[str, np.ndarray]) -> dict:
        """Query dataset using hard (integer) indices.

        Args:
            hard_indices: Dictionary mapping feature names to integer indices.

        Returns:
            Dictionary of feature values stacked as tensors.
        """
        # TODO(shuheng): look into optimization when using hf_dataset.select
        return {
            key: torch.stack(list(self.hf_dataset.select(q_idx)[key]))
            for key, q_idx in hard_indices.items()
            if key not in self.meta.video_keys
        }

    def _query_videos(self, query_timestamps: dict[str, np.ndarray], ep_idx: int) -> dict[str, torch.Tensor]:
        """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
        in the main process (e.g. by using a second DataLoader with num_workers=0). It will result in a
        Segmentation Fault. This probably happens because a memory reference to the video loader is created in
        the main process and a subprocess fails to access it.
        """
        item = {}
        for vid_key, query_ts in query_timestamps.items():
            video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
            frame_indices_soft = query_ts * self.fps
            if self.image_resample_strategy == "linear":
                frame_indices_floor = np.floor(frame_indices_soft).astype(int)
                dist2floor = frame_indices_soft - frame_indices_floor
                frame_indices_ceil = np.floor(frame_indices_soft) + 1.0 * (dist2floor > 0.0)
                query_ts_floor = (frame_indices_floor / self.fps).tolist()
                query_ts_ceil = (frame_indices_ceil / self.fps).tolist()
                frames_floor = decode_video_frames(
                    video_path, query_ts_floor, self.tolerance_s, self.video_backend
                )
                frames_ceil = decode_video_frames(
                    video_path, query_ts_ceil, self.tolerance_s, self.video_backend
                )
                dist2floor = dist2floor[:, None, None, None]
                frames = frames_ceil * dist2floor + frames_floor * (1 - dist2floor)
            elif self.image_resample_strategy == "nearest":
                query_ts_rounded = (frame_indices_soft.round() / self.fps).tolist()
                frames = decode_video_frames(
                    video_path, query_ts_rounded, self.tolerance_s, self.video_backend
                )
            else:
                raise ValueError(
                    f"Unsupported image_resample_strategy: {self.image_resample_strategy}. Choose 'linear' or 'nearest'."
                )
            item[vid_key] = frames.squeeze(0)

        return item

    def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
        """Add padding mask keys to the item dictionary.

        Args:
            item: Item dictionary to modify.
            padding: Dictionary mapping feature names to boolean padding masks.

        Returns:
            Modified item dictionary with padding keys added.
        """
        for key, val in padding.items():
            item[key] = torch.BoolTensor(val)
        return item

    def __len__(self):
        return self.num_frames

    @retry_random_on_failure
    def __getitem__(self, idx) -> dict:
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        if self.episode_data_index is not None and self.epi2idx is not None:
            ep_end = self.episode_data_index["to"][self.epi2idx[ep_idx]].item()

        episodes_info = self.meta.episodes[ep_idx]

        # Soft indices are floats instead of integers, which allows for different interpolation strategies
        # such as nearest neighbor or linear interpolation.
        query_indices_soft = None
        if self.delta_timestamps_params[0]:
            query_indices_soft, padding = self._get_query_indices_soft(idx, ep_idx)
            query_result = self._query_hf_dataset_soft(query_indices_soft)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices_soft)
            video_frames = self._query_videos(query_timestamps, ep_idx)
            item = {**video_frames, **item}

        if self.image_transforms is not None:
            image_keys = self.meta.camera_keys
            for cam in image_keys:
                item[cam] = self.image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self.meta.tasks[task_idx]

        # If indices is an int, squeeze the feature
        for feature, (_, indices) in self.feature2group.items():
            if isinstance(indices, int):
                item[feature] = item[feature].squeeze(0)

        # The conversion script of the AGI BOT dataset uses a dataloader to enumerate data and compute stats.
        # If standardization were enabled there, those stats would be computed under the mapped names, which
        # is wrong.
        if self.standardize:
            # Add response as a string
            if "response" not in item:
                item["response"] = ""

            episode_index = item["episode_index"].item()
            # Don't convert `timestamp` to `float`, because torch.float64 is not supported on MPS
            timestamp = item["timestamp"]

            # Change data naming to the standard data format
            item = self._to_standard_data_format(item)

            if self.meta.advantages is not None:
                advantage = self.meta.advantages.get((episode_index, timestamp), 0)
                item["advantage"] = torch.tensor(advantage, dtype=torch.bfloat16)
            else:
                item["advantage"] = torch.tensor(0.0, dtype=torch.bfloat16)

            success = episodes_info.get("success", True)

            # Only add the fields below to item when training or evaluating the value functions
            if isinstance(self.cfg.policy, ValueConfig):
                item["return_bin_idx"], item["return_continuous"] = calculate_return_bins_with_equal_width(
                    success,
                    self.cfg.policy.reward_config.number_of_bins,
                    ep_end,
                    self.cfg.policy.reward_config.reward_normalizer,
                    idx,
                    self.cfg.policy.reward_config.C_neg,
                )

                item["return_bin_idx"] = torch.tensor(item["return_bin_idx"], dtype=torch.long)
                item["return_continuous"] = torch.tensor(item["return_continuous"], dtype=torch.float32)
                # success, episode_end_idx and last_step are required for calculating the advantage
                if self.return_advantage_input:
                    item["success"] = success
                    item["episode_end_idx"] = ep_end
                    item["current_idx"] = idx
                    item["last_step"] = idx + self.cfg.policy.reward_config.N_steps_look_ahead >= ep_end
                    item["episode_index"] = episode_index
                    item["timestamp"] = timestamp
            else:
                item["return_bin_idx"] = torch.tensor(0, dtype=torch.long)
                item["return_continuous"] = torch.tensor(0, dtype=torch.float32)

            # Sanity check for action chunk lengths
            assert item["actions"].shape[0] == self.cfg.action_chunk
            assert item["action_is_pad"].shape[0] == self.cfg.action_chunk

        return item
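
    # Consumption sketch (illustrative; `dataset` is a hypothetical instance with
    # standardization enabled): items are dicts of tensors plus string fields such
    # as "task", so they compose directly with a standard PyTorch DataLoader:
    #
    #     from torch.utils.data import DataLoader
    #     loader = DataLoader(dataset, batch_size=8, num_workers=4, shuffle=True)
    #     batch = next(iter(loader))
    #     batch["actions"].shape  # (8, action_chunk, action_dim), where action_dim
    #                             # is whatever the dataset's features define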

    def _get_feature_mapping_key(self) -> str:
        return self.repo_id

    def __repr__(self):
        feature_keys = list(self.features)
        return (
            f"{self.__class__.__name__}({{\n"
            f"    Repository ID: '{self.repo_id}',\n"
            f"    Number of selected episodes: '{self.num_episodes}',\n"
            f"    Number of selected samples: '{self.num_frames}',\n"
            f"    Features: '{feature_keys}',\n"
            "})\n"
        )

    def create_episode_buffer(self, episode_index: int | None = None) -> dict:
        current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
        ep_buffer = {}
        # size and task are special cases that are not in self.features
        ep_buffer["size"] = 0
        ep_buffer["task"] = []
        for key in self.features:
            ep_buffer[key] = current_ep_idx if key == "episode_index" else []
        return ep_buffer

    def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
        fpath = DEFAULT_IMAGE_PATH.format(
            image_key=image_key, episode_index=episode_index, frame_index=frame_index
        )
        return self.root / fpath

    def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
        if self.image_writer is None:
            if isinstance(image, torch.Tensor):
                image = image.cpu().numpy()
            write_image(image, fpath)
        else:
            self.image_writer.save_image(image=image, fpath=fpath)

    def add_frame(self, frame: dict) -> None:
        """
        This function only adds the frame to the episode_buffer. Apart from images, which are written to a
        temporary directory, nothing is written to disk. To save those frames, the 'save_episode()' method
        then needs to be called.
        """
        # Convert torch to numpy if needed
        for name in frame:
            if isinstance(frame[name], torch.Tensor):
                frame[name] = frame[name].numpy()

        validate_frame(frame, self.features)

        if self.episode_buffer is None:
            self.episode_buffer = self.create_episode_buffer()

        # Automatically add frame_index and timestamp to the episode buffer
        frame_index = self.episode_buffer["size"]
        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
        self.episode_buffer["frame_index"].append(frame_index)
        self.episode_buffer["timestamp"].append(timestamp)

        # Add frame features to episode_buffer
        for key in frame:
            if key == "task":
                # Note: we associate the task in natural language to its task index during `save_episode`
                self.episode_buffer["task"].append(frame["task"])
                continue

            if key not in self.features:
                raise ValueError(
                    f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
                )

            if self.features[key]["dtype"] in ["image", "video"]:
                img_path = self._get_image_file_path(
                    episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
                )
                if frame_index == 0:
                    img_path.parent.mkdir(parents=True, exist_ok=True)
                self._save_image(frame[key], img_path)
                self.episode_buffer[key].append(str(img_path))
            else:
                self.episode_buffer[key].append(frame[key])

        self.episode_buffer["size"] += 1

    def save_episode(self, episode_data: dict | None = None) -> None:
        """
        This will save to disk the current episode in self.episode_buffer.

        Args:
            episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
                save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
                None.
        """
        if not episode_data:
            episode_buffer = self.episode_buffer
        else:
            episode_buffer = episode_data

        validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)

        # size and task are special cases that won't be added to hf_dataset
        episode_length = episode_buffer.pop("size")
        tasks = episode_buffer.pop("task")
        episode_tasks = list(set(tasks))
        episode_index = episode_buffer["episode_index"]

        episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
        episode_buffer["episode_index"] = np.full((episode_length,), episode_index)

        # Add new tasks to the tasks dictionary
        for task in episode_tasks:
            task_index = self.meta.get_task_index(task)
            if task_index is None:
                self.meta.add_task(task)

        # Given tasks in natural language, find their corresponding task indices
        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])

        for key, ft in self.features.items():
            # index, episode_index and task_index are already processed above, and image and video
            # are processed separately by storing the image path and frame info as metadata
            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                continue
            episode_buffer[key] = np.stack(episode_buffer[key])

        self._wait_image_writer()
        self._save_episode_table(episode_buffer, episode_index)
        ep_stats = compute_episode_stats(episode_buffer, self.features)

        if len(self.meta.video_keys) > 0:
            video_paths = self.encode_episode_videos(episode_index)
            for key in self.meta.video_keys:
                episode_buffer[key] = video_paths[key]

        # `meta.save_episode` must be executed after encoding the videos
        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)

        ep_data_index, _ = get_episode_data_index(self.meta.episodes, [episode_index])
        ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
        check_timestamps_sync(
            episode_buffer["timestamp"],
            episode_buffer["episode_index"],
            ep_data_index_np,
            self.fps,
            self.tolerance_s,
        )

        video_files = list(self.root.rglob("*.mp4"))
        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)

        parquet_files = list(self.root.rglob("*.parquet"))
        assert len(parquet_files) == self.num_episodes

        # Delete images
        img_dir = self.root / "images"
        if img_dir.is_dir():
            shutil.rmtree(self.root / "images")

        if not episode_data:  # Reset the buffer
            self.episode_buffer = self.create_episode_buffer()

        self.episode_data_index, self.epi2idx = get_episode_data_index(self.meta.episodes, self.episodes)
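
    # Recording sketch (illustrative; `features`, `episode_steps` and the `step`
    # attributes are hypothetical):
    #
    #     ds = LeRobotDataset.create("user/my_dataset", fps=30, features=features)
    #     for step in episode_steps:
    #         ds.add_frame({
    #             "observation.image": step.image,
    #             "action": step.action,
    #             "task": "pick up the cube",
    #         })
    #     ds.save_episode()  # writes parquet, encodes videos, then resets the buffer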

    def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
        episode_dict = {key: episode_buffer[key] for key in self.hf_features}
        ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
        ep_dataset = embed_images(ep_dataset)
        self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
        self.hf_dataset.set_transform(hf_transform_to_torch)
        ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
        ep_data_path.parent.mkdir(parents=True, exist_ok=True)
        ep_dataset.to_parquet(ep_data_path)

    def clear_episode_buffer(self) -> None:
        episode_index = self.episode_buffer["episode_index"]
        if self.image_writer is not None:
            for cam_key in self.meta.camera_keys:
                img_dir = self._get_image_file_path(
                    episode_index=episode_index, image_key=cam_key, frame_index=0
                ).parent
                if img_dir.is_dir():
                    shutil.rmtree(img_dir)

        # Reset the buffer
        self.episode_buffer = self.create_episode_buffer()

    def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
        if isinstance(self.image_writer, AsyncImageWriter):
            logging.warning(
                "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
            )

        self.image_writer = AsyncImageWriter(
            num_processes=num_processes,
            num_threads=num_threads,
        )

    def stop_image_writer(self) -> None:
        """
        Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
        remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
        """
        if self.image_writer is not None:
            self.image_writer.stop()
            self.image_writer = None
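
    # Pickling sketch (illustrative): an AsyncImageWriter holds process/thread
    # handles that cannot be pickled, so drop it before handing the dataset to a
    # parallel DataLoader:
    #
    #     ds.stop_image_writer()
    #     loader = DataLoader(ds, num_workers=8)  # workers can now pickle `ds`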

    def _wait_image_writer(self) -> None:
        """Wait for the asynchronous image writer to finish."""
        if self.image_writer is not None:
            self.image_writer.wait_until_done()

    def encode_videos(self) -> None:
        """
        Use ffmpeg to convert frames stored as png into mp4 videos.
        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
        since video encoding with ffmpeg already uses multithreading.
        """
        for ep_idx in range(self.meta.total_episodes):
            self.encode_episode_videos(ep_idx)

    def encode_episode_videos(self, episode_index: int) -> dict:
        """
        Use ffmpeg to convert frames stored as png into mp4 videos.
        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
        since video encoding with ffmpeg already uses multithreading.
        """
        video_paths = {}
        for key in self.meta.video_keys:
            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
            video_paths[key] = str(video_path)
            if video_path.is_file():
                # Skip if the video is already encoded. Could be the case when resuming data recording.
                continue
            img_dir = self._get_image_file_path(
                episode_index=episode_index, image_key=key, frame_index=0
            ).parent
            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)

        return video_paths

    def _separate_image_in_time(self, item: dict):
        name_map = DATA_FEATURES_NAME_MAPPING[self._get_feature_mapping_key()]
        cam_keys = {v for k, v in name_map.items() if k.startswith("camera")}
        for k in cam_keys:
            images = item.pop(k)
            assert len(images) == 2, (
                f"{k} in {self.__class__} is expected to have length 2, got shape={images.shape}"
            )
            item[k + "_local"], item[k] = images

            pads = item.pop(k + "_is_pad")
            assert len(pads) == 2, (
                f"{k}_is_pad in {self.__class__} is expected to have length 2, got shape={pads.shape}"
            )
            item[k + "_local_is_pad"], item[k + "_is_pad"] = pads

    @staticmethod
    def compute_delta_params(
        mean: dict[str, np.ndarray | list[float]],
        std: dict[str, np.ndarray | list[float]],
        lower: dict[str, np.ndarray | list[float]],
        upper: dict[str, np.ndarray | list[float]],
    ):
        r"""Process the parameters `mean`, `std`, `lower` and `upper` for delta timestamps.

        Delta timestamps will be computed dynamically in `__getitem__` with `clip(dT, lower, upper)`, where `dT`
        follows the Gaussian distribution N(mean, std^2). Each parameter is a dictionary mapping group names to
        sequences of floats.

        For example, mean = {"group1": [-0.1, 0.0, 0.1], "group2": [0.0, 0.2]} indicates that 3 delta timestamps
        for features in group1 will be sampled: time t-0.1, t, and t+0.1; and 2 delta timestamps for features in
        group2 will be sampled: time t and t+0.2, where t is the timestamp of the current data point.

        It is assumed that `std`, `lower`, and `upper` have the same keys as `mean`, and that matching keys have
        values of the same length. If a key is absent from `std`, `lower`, or `upper`, it will be set to a
        default value. Namely, `std` will be set to all 0, `lower` will be set to all `-inf`, and `upper` to all
        `+inf`, with lengths equal to the length of the sequences in `mean` for that key.
        If a key is absent from `mean` but present in `std`, `lower`, or `upper`, it will be ignored.

        After processing, the function returns four dictionaries: `mean`, `std`, `lower`, and `upper`, where
        each key is a group name and each value is a numpy array of floats, satisfying the above conditions.
        """
        inf = float("inf")
        mean = mean or {}
        mean = {k: np.array(v) for k, v in mean.items()}

        std = std or {}
        std = {k: np.array(std.get(k) or np.zeros_like(v)) for k, v in mean.items()}

        lower = lower or {}
        lower = {k: np.array(lower.get(k) or (np.zeros_like(v) - inf)) for k, v in mean.items()}

        upper = upper or {}
        upper = {k: np.array(upper.get(k) or (np.zeros_like(v) + inf)) for k, v in mean.items()}

        for k in mean:
            if not (mean[k].shape == std[k].shape == lower[k].shape == upper[k].shape):
                raise ValueError(
                    f"Delta timestamps parameters for {k} have inconsistent shapes: "
                    f"mean={mean[k].shape}, std={std[k].shape}, lower={lower[k].shape}, upper={upper[k].shape}"
                )

        return mean, std, lower, upper
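
    # Doctest-style sketch of the normalization above (illustrative; "g" is a
    # hypothetical group name):
    #
    #     >>> mean, std, lower, upper = LeRobotDataset.compute_delta_params(
    #     ...     {"g": [0.0, 0.1]}, None, None, None
    #     ... )
    #     >>> std["g"].tolist(), float(lower["g"][0]), float(upper["g"][0])
    #     ([0.0, 0.0], -inf, inf)
    #
    # A missing std defaults to zeros (deterministic offsets); missing bounds
    # default to +/- infinity (no clipping).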

    def _check_feature_group_mapping(self):
        for feature, (group, indices) in self.feature2group.items():
            if group not in self.delta_timestamps_params[0]:
                raise ValueError(
                    f"Feature '{feature}' is mapped to group '{group}', which is not present in "
                    "delta_timestamps_params. Please check the mapping."
                )
            if indices is not None and not isinstance(indices, (int, list)):
                raise ValueError(
                    f"Indices for feature '{feature}' in group '{group}' should be a list, an int, or None"
                )

    @classmethod
    def create(
        cls,
        repo_id: str,
        fps: int,
        root: str | Path | None = None,
        robot_type: str | None = None,
        features: dict | None = None,
        use_videos: bool = True,
        tolerance_s: float = 1e-4,
        image_writer_processes: int = 0,
        image_writer_threads: int = 0,
        video_backend: str | None = None,
        image_resample_strategy: str = "nearest",
        vector_resample_strategy: str = "nearest",
        standardize: bool = True,
    ) -> "LeRobotDataset":
        """Create a LeRobot Dataset from scratch in order to record data."""
        obj = cls.__new__(cls)
        obj.meta = LeRobotDatasetMetadata.create(
            repo_id=repo_id,
            fps=fps,
            root=root,
            robot_type=robot_type,
            features=features,
            use_videos=use_videos,
        )
        obj.repo_id = obj.meta.repo_id
        obj.root = obj.meta.root
        obj.revision = None
        obj.tolerance_s = tolerance_s
        obj.image_writer = None

        if image_writer_processes or image_writer_threads:
            obj.start_image_writer(image_writer_processes, image_writer_threads)

        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
        obj.episode_buffer = obj.create_episode_buffer()

        obj.episodes = None
        obj.hf_dataset = obj.create_hf_dataset()
        obj.image_transforms = None
        obj.delta_timestamps_params = obj.compute_delta_params(None, None, None, None)
        obj.feature2group = {}
        obj.episode_data_index = None
        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
        obj.image_resample_strategy = image_resample_strategy
        obj.vector_resample_strategy = vector_resample_strategy
        obj.standardize = standardize
        obj.episode_data_index, obj.epi2idx = get_episode_data_index(obj.meta.episodes, obj.episodes)
        return obj