opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1910 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """LeRobot dataset implementation for robot learning data management.
+
+ This module provides the core dataset implementation for loading, creating, and
+ managing robot learning datasets. It supports both loading existing datasets from
+ the HuggingFace Hub or local disk, as well as creating new datasets for data
+ recording.
+
+ The dataset structure consists of:
+
+ - Metadata: Info, statistics, tasks, and episode information stored as JSON
+ - Data files: Episode data stored as Parquet files organized by chunks
+ - Videos: Optional video files for camera observations stored as MP4 files
+
+ Key Features:
+
+ - Temporal alignment: Supports delta timestamps for temporal feature
+   alignment, enabling sampling of features at different time offsets with
+   optional Gaussian noise for data augmentation.
+ - Multi-modal support: Handles images, videos, state vectors, actions, and
+   text prompts with automatic format conversion and standardization.
+ - Version compatibility: Automatic version checking and backward compatibility
+   handling for datasets created with older format versions.
+ - Asynchronous image writing: Optional async image writer for high-frequency
+   data recording without blocking the main process.
+ - Statistics management: Per-episode and aggregated statistics for data
+   normalization, with automatic computation and aggregation.
+ - Video handling: Supports multiple video backends (torchcodec, pyav,
+   video_reader) for efficient video encoding and decoding.
+
+ Classes:
+
+     DatasetMetadata
+         Base class for dataset metadata management.
+
+     LeRobotDatasetMetadata
+         Metadata manager for LeRobot datasets with Hub integration, version
+         checking, and statistics loading.
+
+     GroundingDatasetMetadata
+         Metadata manager for grounding datasets.
+
+     BaseDataset
+         Base PyTorch Dataset class with common functionality.
+
+     LeRobotDataset
+         Main dataset class for robot learning data, supporting loading from
+         Hub/local disk, temporal alignment, video/image handling, and data
+         recording.
+
+ Functions:
+     retry_random_on_failure
+         Decorator to retry dataset item retrieval with random indices on failure.
+
+ Example:
+     Load an existing dataset:
+     >>> dataset = LeRobotDataset(cfg, repo_id="my-robot-dataset")
+     >>> dataloader = DataLoader(dataset, batch_size=32)
+
+     Create a new dataset for recording:
+     >>> dataset = LeRobotDataset.create(
+     ...     repo_id="my-new-dataset",
+     ...     fps=30,
+     ...     features={"state": {"shape": (7,), "dtype": "float32"}},
+     ...     use_videos=True
+     ... )
+ """
+
+ import contextlib
+ import functools
+ import logging
+ import math
+ import shutil
+ import traceback
+ from abc import abstractmethod
+ from pathlib import Path
+ from typing import Callable
+
+ import datasets
+ import numpy as np
+ import packaging.version
+ import PIL.Image
+ import torch
+ import torch.nn.functional as F  # noqa: N812
+ import torch.utils
+ from datasets import concatenate_datasets, load_dataset
+ from einops import rearrange
+ from huggingface_hub import HfApi, snapshot_download
+ from huggingface_hub.constants import REPOCARD_NAME
+ from huggingface_hub.errors import RevisionNotFoundError
+
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.constants import HF_OPENTAU_HOME
+ from opentau.datasets.compute_stats import aggregate_stats, compute_episode_stats
+ from opentau.datasets.image_writer import AsyncImageWriter, write_image
+ from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING, LOSS_TYPE_MAPPING
+ from opentau.datasets.utils import (
+     DEFAULT_FEATURES,
+     DEFAULT_IMAGE_PATH,
+     INFO_PATH,
+     TASKS_PATH,
+     append_jsonlines,
+     backward_compatible_episodes_stats,
+     check_timestamps_sync,
+     check_version_compatibility,
+     create_empty_dataset_info,
+     create_lerobot_dataset_card,
+     embed_images,
+     get_delta_indices_soft,
+     get_episode_data_index,
+     get_hf_features_from_features,
+     get_safe_version,
+     hf_transform_to_torch,
+     is_valid_version,
+     load_advantages,
+     load_episodes,
+     load_episodes_stats,
+     load_info,
+     load_stats,
+     load_tasks,
+     validate_episode_buffer,
+     validate_frame,
+     write_episode,
+     write_episode_stats,
+     write_info,
+     write_json,
+ )
+ from opentau.datasets.video_utils import (
+     decode_video_frames,
+     encode_video_frames,
+     get_safe_default_codec,
+     get_video_info,
+ )
+ from opentau.policies.value.configuration_value import ValueConfig
+ from opentau.policies.value.reward import (
+     calculate_return_bins_with_equal_width,
+ )
+ from opentau.utils.utils import on_accelerate_main_proc
+
+
+ def retry_random_on_failure(f):
+     """Decorator to retry dataset item retrieval with random indices on failure.
+
+     When a dataset item fails to load, this decorator will retry with random
+     indices up to `_total_rand_attempts` times before raising an error.
+
+     Args:
+         f: The `__getitem__` method to wrap.
+
+     Returns:
+         Wrapped function that retries on failure.
+     """
+
+     @functools.wraps(f)
+     def wrapped(self, idx):
+         g = getattr(self, "_rr_rng", None)
+         total_attempts = getattr(self, "_total_rand_attempts", 0)
+         if g is None:
+             g = torch.Generator()
+             g.manual_seed(torch.initial_seed())  # different seed per DataLoader worker
+             self._rr_rng = g
+
+         n = len(self)
+         cur = idx
+         exceptions = []
+         indices_tried = []
+         for _ in range(total_attempts + 1):
+             try:
+                 indices_tried.append(cur)
+                 return f(self, cur)
+             except Exception as e:
+                 print(f"Encountered failure to load data at index {cur}; retrying with a different index.")
+                 cur = int(torch.randint(0, n, (1,), generator=g))
+                 exceptions.append(e)
+
+         tb_strings = [
+             f"Attempt {i}: trying to fetch index {item} ...\n"
+             + "".join(traceback.format_exception(type(e), e, e.__traceback__))
+             for i, (e, item) in enumerate(zip(exceptions, indices_tried, strict=False))
+         ]
+         tb_blob = "\n".join(tb_strings)
+         raise RuntimeError(
+             f"Failed to load data after {total_attempts + 1} attempt(s). "
+             "Check the following traceback for each attempt made.\n\n"
+             f"{tb_blob}"
+         )
+
+     return wrapped
+
+
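+ # --- Illustrative usage sketch (added for documentation; not part of the
+ # original module). Shows how a dataset class can opt in to the retry
+ # behavior; `_ToyDataset` is hypothetical. The decorator only relies on
+ # `__len__`, `_total_rand_attempts`, and the lazily created `_rr_rng`.
+ class _ToyDataset(torch.utils.data.Dataset):
+     _total_rand_attempts = 3  # retry up to 3 random indices before giving up
+
+     def __len__(self):
+         return 10
+
+     @retry_random_on_failure
+     def __getitem__(self, idx):
+         if idx == 0:
+             raise OSError("corrupt sample")  # simulated failure triggers a random retry
+         return idx
+
+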
+ CODEBASE_VERSION = "v2.1"
+
+
+ class DatasetMetadata:
+     """Base class for dataset metadata containing info and statistics.
+
+     Attributes:
+         info: Dictionary containing dataset information (features, fps, etc.).
+         stats: Dictionary containing dataset statistics for normalization.
+         repo_id: Repository ID of the dataset (set by subclasses).
+     """
+
+     def __init__(self, *, info: dict = None, stats: dict = None):
+         self.info = info or {"features": {}}
+         self.stats = stats or {}
+
+         for feature_name in self.stats:
+             for metric in self.stats[feature_name]:
+                 if isinstance(self.stats[feature_name][metric], (list, tuple)):
+                     self.stats[feature_name][metric] = np.array(self.stats[feature_name][metric])
+                 # TODO: check stats[feature_name][metric].shape is broadcastable with features[feature_name]["shape"]
+
+         self.repo_id = None
+
+     @property
+     def features(self) -> dict[str, dict]:
+         """All features contained in the dataset."""
+         return self.info["features"]
+
+     @property
+     def image_keys(self) -> list[str]:
+         """Keys to access visual modalities stored as images."""
+         return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
+
+     @property
+     def video_keys(self) -> list[str]:
+         """Keys to access visual modalities stored as videos."""
+         return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
+
+     @property
+     def camera_keys(self) -> list[str]:
+         """Keys to access visual modalities (regardless of their storage method)."""
+         return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
+
+     @property
+     def names(self) -> dict[str, list | dict]:
+         """Names of the various dimensions of vector modalities."""
+         return {key: ft["names"] for key, ft in self.features.items()}
+
+     @property
+     def shapes(self) -> dict:
+         """Shapes for the different features."""
+         return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
+
+
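+ # Minimal sketch (illustrative, not from the original file) of the
+ # `info["features"]` layout the properties above iterate over; the keys and
+ # shapes here are invented.
+ _example_meta = DatasetMetadata(
+     info={
+         "features": {
+             "camera0": {"dtype": "video", "shape": (3, 224, 224), "names": None},
+             "state": {"dtype": "float32", "shape": (7,), "names": ["j0", "j1", "j2", "j3", "j4", "j5", "grip"]},
+         }
+     }
+ )
+ assert _example_meta.video_keys == ["camera0"]  # selected by dtype == "video"
+ assert _example_meta.shapes["state"] == (7,)
+
+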
+ class GroundingDatasetMetadata(DatasetMetadata):
+     """Metadata class for grounding datasets (vision-language datasets)."""
+
+     pass
+
+
+ class LeRobotDatasetMetadata(DatasetMetadata):
+     """Metadata manager for LeRobot datasets with Hub integration and version handling.
+
+     This class manages all metadata for LeRobot datasets, including dataset info,
+     statistics, episodes, tasks, and advantages. It handles loading from local disk
+     or HuggingFace Hub, version compatibility checking, and provides utilities for
+     accessing dataset files and information.
+
+     The class automatically handles:
+     - Loading metadata from local disk or downloading from HuggingFace Hub
+     - Version compatibility checking and automatic version resolution
+     - Backward compatibility with older dataset formats (v2.0 vs v2.1)
+     - Episode and task management
+     - Statistics aggregation (per-episode and global)
+
+     Attributes:
+         repo_id: Repository ID of the dataset on HuggingFace Hub.
+         root: Local root directory where the dataset is stored.
+         revision: Git revision (branch/tag/commit) of the dataset.
+         info: Dictionary containing dataset information (features, fps, paths, etc.).
+         stats: Aggregated statistics dictionary (mean, std, min, max, count).
+         episodes_stats: Per-episode statistics dictionary.
+         episodes: Dictionary mapping episode_index to episode information.
+         tasks: Dictionary mapping task_index to task descriptions.
+         task_to_task_index: Reverse mapping from task description to task_index.
+         advantages: Dictionary mapping (episode_index, timestamp) to advantage values.
+
+     Example:
+         Load metadata from Hub:
+         >>> meta = LeRobotDatasetMetadata("lerobot/aloha_mobile_cabinet")
+         >>> print(f"Total episodes: {meta.total_episodes}")
+
+         Create new dataset metadata:
+         >>> meta = LeRobotDatasetMetadata.create(
+         ...     repo_id="my-dataset",
+         ...     fps=30,
+         ...     features={"state": {"dtype": "float32", "shape": (7,)}}
+         ... )
+     """
+
+     def __init__(
+         self,
+         repo_id: str,
+         root: str | Path | None = None,
+         revision: str | None = None,
+         force_cache_sync: bool = False,
+     ):
+         super().__init__()
+         self.repo_id = repo_id
+         self.revision = revision if revision else CODEBASE_VERSION
+         self.root = Path(root) if root is not None else HF_OPENTAU_HOME / repo_id
+
+         try:
+             if force_cache_sync:
+                 raise FileNotFoundError
+             self.load_metadata()
+         except (FileNotFoundError, NotADirectoryError):
+             if is_valid_version(self.revision):
+                 self.revision = get_safe_version(self.repo_id, self.revision)
+
+             (self.root / "meta").mkdir(exist_ok=True, parents=True)
+             self.pull_from_repo(allow_patterns="meta/")
+             self.load_metadata()
+
+     def load_metadata(self) -> None:
+         """Load dataset metadata from disk.
+
+         Loads info, tasks, episodes, statistics, and advantages from the
+         dataset root directory. Handles version compatibility checks.
+         """
+         self.info = load_info(self.root)
+         check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
+         self.tasks, self.task_to_task_index = load_tasks(self.root)
+         self.episodes = load_episodes(self.root)
+         if self._version < packaging.version.parse("v2.1"):
+             self.stats = load_stats(self.root)
+             self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
+         else:
+             self.episodes_stats = load_episodes_stats(self.root)
+             self.stats = aggregate_stats(list(self.episodes_stats.values()))
+
+         self.advantages = load_advantages(self.root)
+
+     def pull_from_repo(
+         self,
+         allow_patterns: list[str] | str | None = None,
+         ignore_patterns: list[str] | str | None = None,
+     ) -> None:
+         snapshot_download(
+             self.repo_id,
+             repo_type="dataset",
+             revision=self.revision,
+             local_dir=self.root,
+             allow_patterns=allow_patterns,
+             ignore_patterns=ignore_patterns,
+         )
+
+     @property
+     def _version(self) -> packaging.version.Version:
+         """Codebase version used to create this dataset."""
+         return packaging.version.parse(self.info["codebase_version"])
+
+     def get_data_file_path(self, ep_index: int) -> Path:
+         """Get the file path for a specific episode's parquet data file.
+
+         Args:
+             ep_index: Episode index.
+
+         Returns:
+             Path to the parquet file for the episode.
+         """
+         ep_chunk = self.get_episode_chunk(ep_index)
+         fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
+         return Path(fpath)
+
+     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
+         """Get the file path for a specific episode's video file.
+
+         Args:
+             ep_index: Episode index.
+             vid_key: Video key/name (e.g., "camera0").
+
+         Returns:
+             Path to the video file for the episode.
+         """
+         ep_chunk = self.get_episode_chunk(ep_index)
+         fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+         return Path(fpath)
+
+     def get_episode_chunk(self, ep_index: int) -> int:
+         """Get the chunk index for a given episode index.
+
+         Episodes are grouped into chunks for efficient storage.
+
+         Args:
+             ep_index: Episode index.
+
+         Returns:
+             Chunk index containing this episode.
+         """
+         return ep_index // self.chunks_size
+
+     @property
+     def data_path(self) -> str:
+         """Formattable string for the parquet files."""
+         return self.info["data_path"]
+
+     @property
+     def video_path(self) -> str | None:
+         """Formattable string for the video files."""
+         return self.info["video_path"]
+
+     @property
+     def robot_type(self) -> str | None:
+         """Robot type used in recording this dataset."""
+         return self.info["robot_type"]
+
+     @property
+     def fps(self) -> int:
+         """Frames per second used during data collection."""
+         return self.info["fps"]
+
+     @property
+     def total_episodes(self) -> int:
+         """Total number of episodes available."""
+         return self.info["total_episodes"]
+
+     @property
+     def total_frames(self) -> int:
+         """Total number of frames saved in this dataset."""
+         return self.info["total_frames"]
+
+     @property
+     def total_tasks(self) -> int:
+         """Total number of different tasks performed in this dataset."""
+         return self.info["total_tasks"]
+
+     @property
+     def total_chunks(self) -> int:
+         """Total number of chunks (groups of episodes)."""
+         return self.info["total_chunks"]
+
+     @property
+     def chunks_size(self) -> int:
+         """Max number of episodes per chunk."""
+         return self.info["chunks_size"]
+
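+     # Worked example (illustrative comment, not in the original source): with
+     # chunks_size = 1000 and a v2.x path template such as
+     # "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet",
+     # episode 1234 maps to chunk 1234 // 1000 == 1, i.e.
+     # "data/chunk-001/episode_001234.parquet".
+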
+     def get_task_index(self, task: str) -> int | None:
+         """
+         Given a task in natural language, returns its task_index if the task already exists in the dataset,
+         otherwise returns None.
+         """
+         return self.task_to_task_index.get(task, None)
+
+     def add_task(self, task: str):
+         """
+         Given a task in natural language, add it to the dictionary of tasks.
+         """
+         if task in self.task_to_task_index:
+             raise ValueError(f"The task '{task}' already exists and can't be added twice.")
+
+         task_index = self.info["total_tasks"]
+         self.task_to_task_index[task] = task_index
+         self.tasks[task_index] = task
+         self.info["total_tasks"] += 1
+
+         task_dict = {
+             "task_index": task_index,
+             "task": task,
+         }
+         append_jsonlines(task_dict, self.root / TASKS_PATH)
+
+     def save_episode(
+         self,
+         episode_index: int,
+         episode_length: int,
+         episode_tasks: list[str],
+         episode_stats: dict[str, dict],
+     ) -> None:
+         self.info["total_episodes"] += 1
+         self.info["total_frames"] += episode_length
+
+         chunk = self.get_episode_chunk(episode_index)
+         if chunk >= self.total_chunks:
+             self.info["total_chunks"] += 1
+
+         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
+         self.info["total_videos"] += len(self.video_keys)
+         if len(self.video_keys) > 0:
+             self.update_video_info()
+
+         write_info(self.info, self.root)
+
+         episode_dict = {
+             "episode_index": episode_index,
+             "tasks": episode_tasks,
+             "length": episode_length,
+         }
+         self.episodes[episode_index] = episode_dict
+         write_episode(episode_dict, self.root)
+
+         self.episodes_stats[episode_index] = episode_stats
+         self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
+         write_episode_stats(episode_index, episode_stats, self.root)
+
+     def update_video_info(self) -> None:
+         """
+         Warning: this function writes info from the first episode's videos, implicitly assuming that all
+         videos have been encoded the same way. Also, this means it assumes the first episode exists.
+         """
+         for key in self.video_keys:
+             if not self.features[key].get("info", None):
+                 video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
+                 self.info["features"][key]["info"] = get_video_info(video_path)
+
+     def __repr__(self):
+         feature_keys = list(self.features)
+         return (
+             f"{self.__class__.__name__}({{\n"
+             f"    Repository ID: '{self.repo_id}',\n"
+             f"    Total episodes: '{self.total_episodes}',\n"
+             f"    Total frames: '{self.total_frames}',\n"
+             f"    Features: '{feature_keys}',\n"
+             "})"
+         )
+
+     @classmethod
+     def create(
+         cls,
+         repo_id: str,
+         fps: int,
+         root: str | Path | None = None,
+         robot_type: str | None = None,
+         features: dict | None = None,
+         use_videos: bool = True,
+     ) -> "LeRobotDatasetMetadata":
+         """Creates metadata for a LeRobotDataset."""
+         obj = cls.__new__(cls)
+         obj.repo_id = repo_id
+         obj.root = Path(root) if root is not None else HF_OPENTAU_HOME / repo_id
+
+         obj.root.mkdir(parents=True, exist_ok=False)
+
+         if features is None:
+             raise ValueError("Dataset features must be explicitly passed upon creation.")
+         else:
+             # TODO(aliberts, rcadene): implement sanity check for features
+             features = {**features, **DEFAULT_FEATURES}
+
+         # check that none of the features contains a "/" in its name,
+         # as this would break the dict flattening in the stats computation, which uses '/' as separator
+         for key in features:
+             if "/" in key:
+                 raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
+
+         obj.tasks, obj.task_to_task_index = {}, {}
+         obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
+         obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
+         if len(obj.video_keys) > 0 and not use_videos:
+             raise ValueError("Video features were provided while 'use_videos' is False.")
+         write_json(obj.info, obj.root / INFO_PATH)
+         obj.revision = None
+         return obj
+
+
+ class BaseDataset(torch.utils.data.Dataset):
+     """Base class for all robot learning datasets.
+
+     This abstract base class provides common functionality for both LeRobotDataset
+     and GroundingDataset, including data format standardization, image processing,
+     and vector padding. It ensures all datasets conform to a standard format
+     regardless of their source or structure.
+
+     Key Features:
+     - Standard data format conversion: Maps dataset-specific feature names
+       to standard names (camera0, camera1, state, actions, etc.)
+     - Image standardization: Resizes and pads images to target resolution
+       while maintaining aspect ratio
+     - Vector padding: Pads state and action vectors to maximum dimensions
+     - Data type conversion: Converts floating-point tensors to bfloat16 for
+       memory efficiency
+     - String normalization: Ensures prompts and responses have consistent
+       newline formatting
+
+     Subclasses must implement:
+     - `_get_feature_mapping_key()`: Returns the key used for feature name
+       mapping (e.g., "lerobot/aloha_mobile_cabinet")
+     - `_separate_image_in_time()`: Separates temporal image sequences into
+       individual frames
+
+     Attributes:
+         resolution: Target image resolution (height, width).
+         num_cams: Number of camera views in each sample.
+         max_state_dim: Maximum dimension for state vectors.
+         max_action_dim: Maximum dimension for action vectors.
+         action_chunk: Number of actions processed in a chunk.
+
+     Example:
+         Create a custom dataset:
+         >>> class MyDataset(BaseDataset):
+         ...     def _get_feature_mapping_key(self):
+         ...         return "my-dataset"
+         ...     def _separate_image_in_time(self, item):
+         ...         pass  # No temporal separation needed
+     """
+
+     def __init__(self, cfg: TrainPipelineConfig):
+         super().__init__()
+         # Standard Data Format parameters
+         self.resolution = cfg.resolution  # resolution of images (H, W) in data sample
+         self.num_cams = cfg.num_cams  # number of cameras in each data sample
+         self.max_state_dim = cfg.max_state_dim  # maximum dimension of the state vector
+         self.max_action_dim = cfg.max_action_dim  # maximum dimension of the action vector
+         self.action_chunk = cfg.action_chunk  # number of actions to be processed in a chunk
+
+     @abstractmethod
+     def _get_feature_mapping_key(self) -> str:
+         r"""Returns the key used for feature mapping."""
+         pass
+
+     @abstractmethod
+     def _separate_image_in_time(self, item: dict):
+         r"""Some keys correspond to 2 images, where the first is the image at the current timestamp and the
+         second is the image from some time ago. We separate these 2 images into different keys by modifying
+         the `item` dictionary. For example, {"image_key": torch.zeros(2, 3, 224, 224), "image_key_is_pad":
+         [False, True]} will become
+         {
+             "image_key": torch.zeros(3, 224, 224),
+             "image_key_is_pad": False,
+         }.
+         """
+         raise NotImplementedError
+
+     def _standardize_images(self, item, standard_item, n_cams, is_local) -> list[bool]:
+         """Standardize image features to a common format.
+
+         Resizes images to the target resolution with padding, and tracks
+         which images are padded.
+
+         Args:
+             item: Input item dictionary with original image keys.
+             standard_item: Output dictionary to populate with standardized images.
+             n_cams: Number of cameras to process.
+             is_local: Whether processing local (past) images.
+
+         Returns:
+             List of boolean values indicating which images are padded.
+         """
+         name_map = DATA_FEATURES_NAME_MAPPING[self._get_feature_mapping_key()]
+         image_is_pad = []
+         for cam_idx in range(n_cams):
+             std_key = f"camera{cam_idx}"
+             key = name_map.get(std_key)
+
+             if key is None:
+                 standard_item[std_key] = torch.zeros((3, *self.resolution))
+                 image_is_pad.append(True)
+             else:
+                 standard_item[std_key] = self.resize_with_pad(
+                     item[key],
+                     self.resolution[1],
+                     self.resolution[0],
+                     pad_value=0,
+                 )
+                 image_is_pad.append(item.get(key + "_is_pad", torch.tensor(False)).item())
+             assert (
+                 len(standard_item[std_key].shape) == 3
+                 and standard_item[std_key].shape[0] == 3
+                 and standard_item[std_key].min() >= 0.0 - 1e-6  # bfloat16 results in precision loss
+                 and standard_item[std_key].max() <= 1.0 + 1e-6  # bfloat16 results in precision loss
+             ), (
+                 f"Expected image {std_key} to have shape (3, H, W) with values in [0, 1], "
+                 f"Got shape {standard_item[std_key].shape}, "
+                 f"min={standard_item[std_key].min()}, "
+                 f"max={standard_item[std_key].max()}, "
+                 f"self={self._get_feature_mapping_key()}."
+             )
+
+         return image_is_pad
+
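+     # Illustrative shape of a DATA_FEATURES_NAME_MAPPING entry consumed above
+     # (hypothetical keys, not copied from the real mapping table):
+     #
+     #     {"camera0": "observation.images.top", "state": "observation.state",
+     #      "actions": "action", "prompt": "task"}
+     #
+     # Cameras absent from the mapping are zero-filled and flagged in image_is_pad.
+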
+     def _to_standard_data_format(self, item: dict) -> dict:
+         """Convert dataset item to standard data format.
+
+         Standardizes feature names, separates images in time, pads vectors,
+         and ensures consistent data types and formats.
+
+         Args:
+             item: Raw dataset item dictionary.
+
+         Returns:
+             Dictionary with standardized feature names and formats.
+         """
+         name_map = DATA_FEATURES_NAME_MAPPING[self._get_feature_mapping_key()]
+         self._separate_image_in_time(item)
+
+         standard_item = {}
+         img_is_pad = self._standardize_images(item, standard_item, self.num_cams, False)
+
+         for new_key, key in name_map.items():
+             if new_key.startswith("camera"):
+                 continue
+             standard_item[new_key] = item[key]
+
+         # pad state and action vectors
+         standard_item["state"] = self.pad_vector(standard_item["state"], self.max_state_dim)
+         standard_item["actions"] = self.pad_vector(standard_item["actions"], self.max_action_dim)
+
+         standard_item["img_is_pad"] = torch.tensor(img_is_pad, dtype=torch.bool)
+         standard_item["action_is_pad"] = item[name_map["actions"] + "_is_pad"]
+
+         # add loss type
+         standard_item["loss_type"] = LOSS_TYPE_MAPPING[self._get_feature_mapping_key()]
+
+         # cast all tensors in standard_item to bfloat16
+         for key, value in standard_item.items():
+             if isinstance(value, torch.Tensor) and value.dtype.is_floating_point:
+                 standard_item[key] = value.to(dtype=torch.bfloat16)
+
+         # ensure that non-empty strings contain exactly one newline character at the end of the string
+         for key in ["prompt", "response"]:
+             if standard_item[key].endswith(
+                 "\n"
+             ):  # ensure there isn't going to be an extra space at the end after calling replace
+                 standard_item[key] = standard_item[key][:-1]
+             standard_item[key] = standard_item[key].replace("\n", " ") + "\n"
+
+         return standard_item
+
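+     # Worked example of the prompt/response normalization above (illustrative,
+     # not part of the original file): interior newlines become spaces and
+     # exactly one trailing newline is kept.
+     #     "pick up the cube\nslowly\n" -> "pick up the cube slowly\n"
+     #     "pick up the cube"           -> "pick up the cube\n"
+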
+     def resize_with_pad(self, img, width, height, pad_value=0) -> torch.Tensor:
+         """Resize an image to target dimensions with padding.
+
+         Maintains aspect ratio by resizing to fit within target dimensions,
+         then pads on the left and top to reach exact target size.
+
+         Args:
+             img: Input image tensor of shape (C, H, W).
+             width: Target width.
+             height: Target height.
+             pad_value: Value to use for padding. Defaults to 0.
+
+         Returns:
+             Resized and padded image tensor of shape (C, height, width).
+
+         Raises:
+             ValueError: If the input image doesn't have 4 dimensions after reshaping.
+         """
+         # assume no-op when width and height already fit
+         img = rearrange(img, "c h w -> 1 c h w")
+         if img.ndim != 4:
+             raise ValueError(f"Expected (b, c, h, w), got {img.shape}")
+
+         cur_height, cur_width = img.shape[2:]
+
+         ratio = max(cur_width / width, cur_height / height)
+         resized_height = int(cur_height / ratio)
+         resized_width = int(cur_width / ratio)
+         resized_img = F.interpolate(
+             img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+         )
+
+         pad_height = max(0, int(height - resized_height))
+         pad_width = max(0, int(width - resized_width))
+
+         # pad on left and top of image
+         padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
+
+         # rearrange back to (c, h, w)
+         padded_img = rearrange(padded_img, "1 c h w -> c h w")
+         return padded_img
+
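+     # Worked example for resize_with_pad (illustrative, not in the original
+     # source): a (3, 480, 640) frame resized to width=224, height=224 gives
+     # ratio = max(640/224, 480/224) ~= 2.857, so the frame becomes (3, 168, 224)
+     # and 56 rows of padding are added on top to reach (3, 224, 224).
+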
+     @staticmethod
+     def pad_vector(vector, new_dim):
+         """Only the last dimension of the vector is padded to 'new_dim' with zeros."""
+         if vector.shape[-1] == new_dim:
+             return vector
+         shape = list(vector.shape)
+         current_dim = shape[-1]
+         shape[-1] = new_dim
+         new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
+         new_vector[..., :current_dim] = vector
+         return new_vector
+
+
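+ # Quick sanity check for BaseDataset.pad_vector (illustrative, not part of the
+ # original module): a 7-dim state vector padded to a hypothetical
+ # max_state_dim of 32.
+ _state = torch.ones(7)
+ _padded = BaseDataset.pad_vector(_state, 32)
+ assert _padded.shape == (32,) and _padded[7:].sum() == 0  # zeros fill dims 7..31
+
+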
+ class LeRobotDataset(BaseDataset):
+     """Main dataset class for loading and managing robot learning data.
+
+     This class provides a PyTorch Dataset interface for robot learning datasets
+     stored in the LeRobot format. It supports loading from HuggingFace Hub or
+     local disk, handles temporal alignment with delta timestamps, manages video
+     and image data, and provides data recording capabilities.
+
+     The dataset structure consists of:
+     - Metadata: JSON files containing info, statistics, episodes, tasks
+     - Data files: Parquet files organized by chunks containing episode data
+     - Videos: Optional MP4 files for camera observations
+
+     Key Features:
+     - Hub integration: Automatic download from HuggingFace Hub with version
+       compatibility checking
+     - Temporal alignment: Delta timestamps enable sampling features at
+       different time offsets with optional Gaussian noise for augmentation
+     - Video/image handling: Supports both video files and individual images
+       with automatic frame extraction and synchronization
+     - Episode filtering: Load specific episodes by index
+     - Data recording: Create new datasets and add episodes programmatically
+     - Statistics: Per-episode and aggregated statistics for normalization
+
+     Two Usage Modes:
+     1. Loading existing datasets: From local disk or HuggingFace Hub
+     2. Creating new datasets: Using the `create()` classmethod for data
+        recording
+
+     Attributes:
+         cfg: Training pipeline configuration.
+         repo_id: Repository ID of the dataset.
+         root: Local root directory for the dataset.
+         meta: LeRobotDatasetMetadata instance containing all metadata.
+         hf_dataset: HuggingFace Dataset containing parquet data.
+         episodes: Dictionary mapping episode_index to episode info.
+         image_transforms: Optional image transforms to apply.
+         delta_timestamps_params: Processed delta timestamp parameters.
+         feature2group: Mapping from features to temporal groups.
+         video_backend: Backend used for video decoding.
+         standardize: Whether to standardize data format.
+
+     Example:
+         Load dataset from Hub:
+         >>> dataset = LeRobotDataset(cfg, repo_id="lerobot/aloha")
+         >>> dataloader = DataLoader(dataset, batch_size=32)
+
+         Load specific episodes:
+         >>> dataset = LeRobotDataset(
+         ...     cfg,
+         ...     repo_id="lerobot/aloha",
+         ...     episodes=[0, 1, 2, 5, 10]
+         ... )
+
+         Create new dataset for recording:
+         >>> dataset = LeRobotDataset.create(
+         ...     cfg,
+         ...     repo_id="my-new-dataset",
+         ...     fps=30,
+         ...     features={"state": {"dtype": "float32", "shape": (7,)}}
+         ... )
+     """
+
+     def __init__(
+         self,
+         cfg: TrainPipelineConfig,
+         repo_id: str,
+         root: str | Path | None = None,
+         episodes: list[int] | None = None,
+         image_transforms: Callable | None = None,
+         delta_timestamps: dict[str, np.ndarray | list[float]] | None = None,
+         delta_timestamps_std: dict[str, np.ndarray | list[float]] | None = None,
+         delta_timestamps_lower: dict[str, np.ndarray | list[float]] | None = None,
+         delta_timestamps_upper: dict[str, np.ndarray | list[float]] | None = None,
+         feature2group: dict[str, tuple[str, (list[int] | int | None)]] | None = None,
+         tolerance_s: float = 1e-4,
+         revision: str | None = None,
+         force_cache_sync: bool = False,
+         download_videos: bool = True,
+         video_backend: str | None = None,
+         image_resample_strategy: str = "nearest",
+         vector_resample_strategy: str = "nearest",
+         standardize: bool = True,
+         return_advantage_input: bool = False,
+     ):
+         """Initialize LeRobotDataset.
+
+         2 modes are available for instantiating this class, depending on 2 different use cases:
+
+         1. Your dataset already exists:
+
+            - On your local disk in the 'root' folder. This is typically the case when you recorded your
+              dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
+              with 'root' will load your dataset directly from disk. This can happen while you're offline (no
+              internet connection).
+
+            - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
+              your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
+              the dataset from that address and load it, provided your dataset is compliant with
+              codebase_version v2.0. If your dataset has been created before this new format, you will be
+              prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at
+              lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
+
+         2. Your dataset doesn't already exist (either on local disk or on the Hub): you can create an empty
+            LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or porting
+            an existing dataset to the LeRobotDataset format.
+
+         In terms of files, LeRobotDataset encapsulates 3 main things:
+
+         - metadata:
+
+           - info contains various information about the dataset like shapes, keys, fps etc.
+           - stats stores the dataset statistics of the different modalities for normalization
+           - tasks contains the prompts for each task of the dataset, which can be used for
+             task-conditioned training.
+
+         - hf_dataset (from datasets.Dataset), which will read any values from parquet files.
+
+         - videos (optional) from which frames are loaded to be synchronous with data from parquet files.
+
+         A typical LeRobotDataset looks like this from its root path::
+
+             .
+             ├── data
+             │   ├── chunk-000
+             │   │   ├── episode_000000.parquet
+             │   │   ├── episode_000001.parquet
+             │   │   ├── episode_000002.parquet
+             │   │   └── ...
+             │   ├── chunk-001
+             │   │   ├── episode_001000.parquet
+             │   │   ├── episode_001001.parquet
+             │   │   ├── episode_001002.parquet
+             │   │   └── ...
+             │   └── ...
+             ├── meta
+             │   ├── episodes.jsonl
+             │   ├── info.json
+             │   ├── stats.json
+             │   └── tasks.jsonl
+             └── videos
+                 ├── chunk-000
+                 │   ├── observation.images.laptop
+                 │   │   ├── episode_000000.mp4
+                 │   │   ├── episode_000001.mp4
+                 │   │   ├── episode_000002.mp4
+                 │   │   └── ...
+                 │   ├── observation.images.phone
+                 │   │   ├── episode_000000.mp4
+                 │   │   ├── episode_000001.mp4
+                 │   │   ├── episode_000002.mp4
+                 │   │   └── ...
+                 ├── chunk-001
+                 └── ...
+
+         Note that this file-based structure is designed to be as versatile as possible. The files are split
+         by episodes, which allows more granular control over which episodes one wants to use and download.
+         The structure of the dataset is entirely described in the info.json file, which can be easily
+         downloaded or viewed directly on the hub before downloading any actual data. The file types used are
+         very simple and do not need complex tools to be read: only .parquet, .json and .mp4 files are used
+         (and .md for the README).
+
+         Args:
+             cfg (TrainPipelineConfig): Training configuration object.
+             repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
+                 will be stored under root/repo_id.
+             root (Path | None, optional): Local directory to use for downloading/writing files. You can also
+                 set the HF_OPENTAU_HOME environment variable to point to a different location. Defaults to
+                 '~/.cache/huggingface/opentau'.
+             episodes (list[int] | None, optional): If specified, this will only load episodes specified by
+                 their episode_index in this list. Defaults to None.
+             image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
+                 torchvision.transforms.v2 here, which will be applied to visual modalities (whether they
+                 come from videos or images). Defaults to None.
+             delta_timestamps (dict[str, list[float]] | None, optional): Dictionary where each key is a group
+                 name and its corresponding value is a list of delta timestamps in seconds. For example,
+                 {'group1': [0, 0.1]} means features of group1 will be returned as a chunk of 2, with the
+                 first element being the value at the current time and the second element being the value at
+                 current time + 0.1 seconds. This will also add a key named '{feature}_is_pad' to the
+                 returned item, with a boolean type and a length of 2, indicating whether the feature is
+                 padded or not. Padding will happen when t + 0.1 is outside the episode time range.
+                 Defaults to None.
+             delta_timestamps_std (dict[str, list[float]] | None, optional): Similar to delta_timestamps, but
+                 specifies an optional standard deviation for the delta timestamps. If a key is absent, the
+                 delta timestamps for that key will be deterministic. If a key is present without
+                 corresponding delta_timestamps, it will be ignored. E.g., delta_timestamps={'group1': [0, 0.1]}
+                 and delta_timestamps_std={'group1': [0, 0.05]} will result in a chunk of 2, with the first
+                 element being the feature at the current time and the second element at a time following a
+                 Gaussian distribution N(t+0.1, 0.05^2). When it takes on a value outside the episode, the
+                 corresponding element in `{feature}_is_pad` will be set to True. Defaults to None.
+             delta_timestamps_lower (dict[str, list[float]] | None, optional): Similar to
+                 delta_timestamps_std, but specifies a minimum value for the delta timestamps. When
+                 specified, the delta timestamps will be lower-clipped accordingly. Defaults to None.
+             delta_timestamps_upper (dict[str, list[float]] | None, optional): Similar to
+                 delta_timestamps_std, but specifies a maximum value for the delta timestamps. When
+                 specified, the delta timestamps will be upper-clipped accordingly. Defaults to None.
+             feature2group (dict[str, tuple[str, (list[int] | int | None)]] | None, optional): Dictionary
+                 mapping every individual feature to a tuple of (group name, indices). Group names are keys
+                 passed to delta_timestamps. If `indices` is None, will use all indices in the group. If
+                 indices is a list, will use only those indices in the corresponding order, including
+                 duplicates if present. If indices is an int, will return that index only, resulting in a
+                 reduction in ndim by 1.
+                 For example, `feature2group={'action': ('group1', None), 'observation.state': ('group2', 0),
+                 'observation.images.left_hand': ('group2', [0, 1])}` means the feature `action` will use
+                 resolved `delta_timestamps` from `group1` and will return every index. Also,
+                 `observation.state` will pick the first element (index-0) of `group2`.
+                 `observation.images.left_hand` will pick the first and second elements (indices 0 and 1) of
+                 `group2` and return them as a chunk of 2 images. The first element of
+                 `observation.images.left_hand` and the state vector will always be sampled at the same
+                 timestamp despite having Gaussian noise applied, because they are in the same group.
+             tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually
+                 in sync with the fps value. It is used at the init of the dataset to make sure that each
+                 timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
+                 decoded from video files. Defaults to 1e-4.
+             revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
+                 commit hash. Defaults to the current codebase version tag.
+             download_videos (bool, optional): Flag to download the videos. Note that when set to True but
+                 the video files are already present on local disk, they won't be downloaded again. Defaults
+                 to True.
+             video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to
+                 torchcodec when available on the platform; otherwise, defaults to 'pyav'. You can also use
+                 the 'pyav' decoder used by Torchvision, which used to be the default option, or
+                 'video_reader', another Torchvision decoder.
+             image_resample_strategy (str, optional): Resampling strategy to use for image features. If
+                 'linear', it will use linear interpolation between the two nearest timestamps. If 'nearest',
+                 it will use nearest neighbor interpolation. Defaults to 'nearest'.
+             vector_resample_strategy (str, optional): Resampling strategy to use for non-image features,
+                 such as action or state. If 'linear', it will use linear interpolation between the two
+                 nearest timestamps. If 'nearest', it will use nearest neighbor interpolation. Defaults to
+                 'nearest'.
+             standardize (bool, optional): Flag to enable standardization in `__getitem__`. Defaults to True.
+             return_advantage_input (bool, optional): Flag to return advantage inputs ("success",
+                 "episode_end_idx", "current_idx", "last_step", "episode_index", "timestamp"). Defaults to
+                 False. Ignored if standardize is False.
+         """
+         super().__init__(cfg)
+         self.cfg = cfg
+         self.repo_id = repo_id
+         self.root = Path(root) if root else HF_OPENTAU_HOME / repo_id
+         self.image_transforms = image_transforms
+         if bool(delta_timestamps) ^ bool(feature2group):
+             raise ValueError(
+                 "Either both delta_timestamps and feature2group should be provided, or neither of them."
+             )
+         # delta_timestamps_params is a 4-tuple (mean, std, lower, upper)
+         self.delta_timestamps_params = self.compute_delta_params(
+             delta_timestamps,
+             delta_timestamps_std,
+             delta_timestamps_lower,
+             delta_timestamps_upper,
+         )
+         self.feature2group = feature2group or {}
+         self._check_feature_group_mapping()
+         self.episodes = episodes
+         self.tolerance_s = tolerance_s
+         self.revision = revision if revision else CODEBASE_VERSION
+         self.video_backend = video_backend if video_backend else get_safe_default_codec()
+
+         if image_resample_strategy not in ["linear", "nearest"]:
+             raise ValueError(
+                 f"Invalid image resample strategy: {image_resample_strategy}. Choose 'linear' or 'nearest'."
+             )
+         if vector_resample_strategy not in ["linear", "nearest"]:
+             raise ValueError(
+                 f"Invalid vector resample strategy: {vector_resample_strategy}. Choose 'linear' or 'nearest'."
+             )
+         self.image_resample_strategy = image_resample_strategy
+         self.vector_resample_strategy = vector_resample_strategy
+
+         self.standardize = standardize
+         if return_advantage_input and not standardize:
+             print(
+                 "Warning: `return_advantage_input` is True while `standardize` is False. "
+                 "No advantage inputs will be returned."
+             )
+         self.return_advantage_input = return_advantage_input
+
+         # Unused attributes
+         self.image_writer = None
+         self.episode_buffer = None
+
+         self.root.mkdir(exist_ok=True, parents=True)
+
+         # Load metadata
+         self.meta = LeRobotDatasetMetadata(
+             self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+         )
+         if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
+             episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
+             self.stats = aggregate_stats(episodes_stats)
+
+         if self.episodes is None:
+             self.episodes = list(self.meta.episodes)
+
+         # Load actual data
+         try:
+             if force_cache_sync:
+                 raise FileNotFoundError
+             assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
+             self.hf_dataset = self.load_hf_dataset()
+         except (AssertionError, FileNotFoundError, NotADirectoryError):
+             self.revision = get_safe_version(self.repo_id, self.revision)
+             self.download_episodes(download_videos)
+             self.hf_dataset = self.load_hf_dataset()
+
+         self.episode_data_index, self.epi2idx = get_episode_data_index(self.meta.episodes, self.episodes)
+
+         # Check timestamps
+         # If a transform is set, with_transform will decode all columns of a row before returning the desired column(s).
+         no_transform_ds = self.hf_dataset.with_transform(None).with_format("numpy")
+         logging.info("Checking timestamps synchronization...")
+         timestamps = np.asarray(no_transform_ds["timestamp"], dtype=np.float32)
+         episode_indices = np.asarray(no_transform_ds["episode_index"], dtype=np.int64)
+         ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
+         check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
+
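+     # Illustrative configuration sketch (comment only, not in the original
+     # file): sample the current frame plus an action 0.1 s ahead, with the
+     # future action jittered by Gaussian noise. Group/feature names are invented.
+     #
+     #     dataset = LeRobotDataset(
+     #         cfg,
+     #         repo_id="lerobot/aloha",
+     #         delta_timestamps={"obs": [0.0], "act": [0.0, 0.1]},
+     #         delta_timestamps_std={"act": [0.0, 0.05]},
+     #         feature2group={
+     #             "observation.state": ("obs", 0),
+     #             "action": ("act", None),
+     #         },
+     #     )
+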
+     @on_accelerate_main_proc(local=True, _sync=True)
+     def push_to_hub(
+         self,
+         branch: str | None = None,
+         tags: list | None = None,
+         license: str | None = "apache-2.0",
+         tag_version: bool = True,
+         push_videos: bool = True,
+         private: bool = False,
+         allow_patterns: list[str] | str | None = None,
+         upload_large_folder: bool = False,
+         **card_kwargs,
+     ) -> None:
+         ignore_patterns = ["images/"]
+         if not push_videos:
+             ignore_patterns.append("videos/")
+
+         hub_api = HfApi()
+         hub_api.create_repo(
+             repo_id=self.repo_id,
+             private=private,
+             repo_type="dataset",
+             exist_ok=True,
+         )
+         if branch:
+             hub_api.create_branch(
+                 repo_id=self.repo_id,
+                 branch=branch,
+                 revision=self.revision,
+                 repo_type="dataset",
+                 exist_ok=True,
+             )
+
+         upload_kwargs = {
+             "repo_id": self.repo_id,
+             "folder_path": self.root,
+             "repo_type": "dataset",
+             "revision": branch,
+             "allow_patterns": allow_patterns,
+             "ignore_patterns": ignore_patterns,
+         }
+         if upload_large_folder:
+             hub_api.upload_large_folder(**upload_kwargs)
+         else:
+             hub_api.upload_folder(**upload_kwargs)
+
+         if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
+             card = create_lerobot_dataset_card(
+                 tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
+             )
+             card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
+
+         if tag_version:
+             with contextlib.suppress(RevisionNotFoundError):
+                 hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+             hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+
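+     # Illustrative call (not part of the original file): push a recorded
+     # dataset to a private Hub repo without videos. The arguments shown are
+     # real parameters of push_to_hub above; the branch name is invented.
+     #
+     #     dataset.push_to_hub(branch="dev", private=True, push_videos=False)
+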
+     @on_accelerate_main_proc(local=True, _sync=True)
+     def pull_from_repo(
+         self,
+         allow_patterns: list[str] | str | None = None,
+         ignore_patterns: list[str] | str | None = None,
+     ) -> None:
+         snapshot_download(
+             self.repo_id,
+             repo_type="dataset",
+             revision=self.revision,
+             local_dir=self.root,
+             allow_patterns=allow_patterns,
+             ignore_patterns=ignore_patterns,
+         )
+
+     def download_episodes(self, download_videos: bool = True) -> None:
+         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given,
+         this will only download those episodes (selected by their episode_index). If 'episodes' is None, the
+         whole dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are
+         already present in 'local_dir', they won't be downloaded again.
+         """
+         # TODO(rcadene, aliberts): implement faster transfer
+         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
+         files = None
+         ignore_patterns = None if download_videos else "videos/"
+         if self.episodes is not None:
+             files = self.get_episodes_file_paths()
+
+         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+
+     def get_episodes_file_paths(self) -> list[str]:
+         """Get file paths for all selected episodes.
+
+         Returns paths for both parquet data files and video files (if applicable)
+         for all episodes in the dataset.
+
+         Returns:
+             List of file paths (as strings) for episode data and videos.
+         """
+         episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+         fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
+         if len(self.meta.video_keys) > 0:
+             video_files = [
+                 str(self.meta.get_video_file_path(ep_idx, vid_key))
+                 for vid_key in self.meta.video_keys
+                 for ep_idx in episodes
+             ]
+             fpaths += video_files
+
+         return fpaths
+
1209
+ def load_hf_dataset(self) -> datasets.Dataset:
1210
+ """hf_dataset contains all the observations, states, actions, rewards, etc."""
1211
+ if self.episodes is None:
1212
+ path = str(self.root / "data")
1213
+ hf_dataset = load_dataset("parquet", data_dir=path, split="train")
1214
+ else:
1215
+ files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
1216
+ hf_dataset = load_dataset("parquet", data_files=files, split="train")
1217
+
1218
+ # TODO(aliberts): hf_dataset.set_format("torch")
1219
+ hf_dataset.set_transform(hf_transform_to_torch)
1220
+ return hf_dataset
1221
+
1222
+    def create_hf_dataset(self) -> datasets.Dataset:
+        """Create an empty HuggingFace dataset with the correct features.
+
+        Returns:
+            Empty dataset with features matching the dataset specification.
+        """
+        features = get_hf_features_from_features(self.features)
+        ft_dict = {col: [] for col in features}
+        hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
+
+        # TODO(aliberts): hf_dataset.set_format("torch")
+        hf_dataset.set_transform(hf_transform_to_torch)
+        return hf_dataset
+
+    @property
+    def fps(self) -> int:
+        """Frames per second used during data collection."""
+        return self.meta.fps
+
+    @property
+    def num_frames(self) -> int:
+        """Number of frames in selected episodes."""
+        return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
+
+    @property
+    def num_episodes(self) -> int:
+        """Number of episodes selected."""
+        return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
+
+    @property
+    def features(self) -> dict[str, dict]:
+        return self.meta.features
+
+    @property
+    def hf_features(self) -> datasets.Features:
+        """Features of the hf_dataset."""
+        if self.hf_dataset is not None:
+            return self.hf_dataset.features
+        else:
+            return get_hf_features_from_features(self.features)
+
+    def _get_query_indices_soft(
+        self, idx: int, ep_idx: int
+    ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
+        """Get soft (float) indices for querying features with delta timestamps.
+
+        Computes indices for features based on delta timestamps, accounting for
+        episode boundaries. Returns both query indices and padding masks.
+
+        Args:
+            idx: Current data index.
+            ep_idx: Current episode index.
+
+        Returns:
+            Tuple of (query_indices, padding):
+                - query_indices: Dictionary mapping feature names to soft indices.
+                - padding: Dictionary mapping feature names to boolean padding masks.
+        """
+        ep_start = self.episode_data_index["from"][self.epi2idx[ep_idx]].item()
+        ep_end = self.episode_data_index["to"][self.epi2idx[ep_idx]].item()
+
+        # Get the delta_indices by group
+        delta_indices = get_delta_indices_soft(self.delta_timestamps_params, self.fps)
+        # Map from group to feature
+        delta_indices = {
+            feature: delta_indices[group][
+                slice(None) if indices is None else [indices] if isinstance(indices, int) else indices
+            ]
+            for feature, (group, indices) in self.feature2group.items()
+        }
+        query_indices = {
+            key: np.clip(idx + delta_idx, ep_start, ep_end - 1) for key, delta_idx in delta_indices.items()
+        }
+        padding = {  # Pad values outside of current episode range
+            f"{key}_is_pad": torch.BoolTensor((idx + delta_idx < ep_start) | (idx + delta_idx >= ep_end))
+            for key, delta_idx in delta_indices.items()
+        }
+        return query_indices, padding
+
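To make the boundary handling above concrete, here is a minimal standalone sketch of how clipping and the padding mask interact; the episode bounds and offsets are made up, and none of the class attributes or `get_delta_indices_soft` are used:

```python
import numpy as np
import torch

# Hypothetical episode spanning rows [100, 200) and a current row just after its start.
ep_start, ep_end, idx = 100, 200, 101
delta_idx = np.array([-2.0, -1.0, 0.0, 1.0])  # soft offsets, in rows

# Out-of-episode queries are clamped to valid rows ...
query_indices = np.clip(idx + delta_idx, ep_start, ep_end - 1)  # [100., 100., 101., 102.]
# ... and flagged so the consumer can ignore the duplicated (padded) steps.
is_pad = torch.BoolTensor((idx + delta_idx < ep_start) | (idx + delta_idx >= ep_end))
# is_pad -> tensor([ True, False, False, False])
```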
+    def _get_query_timestamps(
+        self,
+        current_ts: float,
+        query_indices: dict[str, np.ndarray] | None = None,
+    ) -> dict[str, np.ndarray]:
+        """Get query timestamps for video features.
+
+        Converts soft indices to timestamps for video frame extraction.
+        If query_indices is provided, uses them; otherwise uses the current timestamp.
+
+        Args:
+            current_ts: Current timestamp in seconds.
+            query_indices: Optional dictionary of soft indices for features.
+
+        Returns:
+            Dictionary mapping video keys to query timestamps.
+        """
+        if query_indices:
+            # In case values are lists
+            query_indices = {k: np.array(v, dtype=np.float32) for k, v in query_indices.items()}
+            q_indices = next(iter(query_indices.values()))
+            # Pick any (soft) row index, which is guaranteed to be within [ep_start, ep_end), then take the floor
+            in_ep_row_idx = math.floor(q_indices[0])
+            # Index of the episode (not index of row). E.g., episode_index = 36 for row index = 10000
+            ep_idx = self.hf_dataset.select([in_ep_row_idx])["episode_index"][0].item()
+            # Row index where the current episode starts
+            ep_start_row_idx = self.episode_data_index["from"][self.epi2idx[ep_idx]].item()
+        else:
+            ep_start_row_idx = None
+
+        query_timestamps = {}
+        for key in self.meta.video_keys:
+            if query_indices is not None and key in query_indices:
+                query_timestamps[key] = (query_indices[key] - ep_start_row_idx) / self.fps
+            else:
+                query_timestamps[key] = np.array([current_ts], dtype=np.float32)
+
+        return query_timestamps
+
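The index-to-timestamp conversion above is simply `(row - episode_start_row) / fps`, i.e. seconds since the start of the episode's video file. A quick standalone example with hypothetical numbers:

```python
import numpy as np

fps = 30
ep_start_row_idx = 9_900  # hypothetical first row of the episode in the flat dataset
query_indices = np.array([9_900.0, 9_930.0, 9_945.5])  # soft row indices

# Seconds into the episode's video file.
query_timestamps = (query_indices - ep_start_row_idx) / fps
# -> array([0.        , 1.        , 1.51666667])
```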
+    def _query_hf_dataset_soft(self, soft_indices: dict[str, np.ndarray]) -> dict:
+        """Query dataset using soft (float) indices with interpolation.
+
+        Converts soft indices to hard indices based on the resample strategy
+        (linear interpolation or nearest neighbor).
+
+        Args:
+            soft_indices: Dictionary mapping feature names to soft (float) indices.
+
+        Returns:
+            Dictionary of feature values queried from the dataset.
+
+        Raises:
+            ValueError: If vector_resample_strategy is not 'linear' or 'nearest'.
+        """
+        # soft indices are float indices that need to be converted to hard (integer) indices
+        if self.vector_resample_strategy == "linear":
+            floor_indices = {k: np.floor(v).astype(int) for k, v in soft_indices.items()}
+            dist2floor = {k: v - floor_indices[k] for k, v in soft_indices.items()}
+            # In the unlikely case that the soft index is exactly (ep_end - 1), floor will be (ep_end - 1),
+            # and (floor + 1) will be ep_end, which may be out of bounds (despite usually being the start of
+            # the next episode). Therefore, we add 0 instead of 1 whenever the distance to floor is 0.
+            ceil_indices = {k: floor_indices[k] + (dist2floor[k] > 0.0) for k in soft_indices}
+            q_floor = self._query_hf_dataset(floor_indices)
+            q_ceil = self._query_hf_dataset(ceil_indices)
+
+            item = {}
+            for k, d2f in dist2floor.items():
+                if k not in q_floor:
+                    continue
+                d2f = torch.tensor(d2f)
+                d2f = rearrange(d2f, f"n -> {'n' + ' 1' * (q_floor[k].ndim - 1)}")
+                item[k] = (1.0 - d2f) * q_floor[k] + d2f * q_ceil[k]
+            return item
+        elif self.vector_resample_strategy == "nearest":
+            hard_indices = {k: v.round().astype(int) for k, v in soft_indices.items()}
+            return self._query_hf_dataset(hard_indices)
+
+        raise ValueError(
+            f"Unsupported vector_resample_strategy: {self.vector_resample_strategy}. Choose 'linear' or 'nearest'."
+        )
+
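The 'linear' branch above is a per-feature lerp between the two rows surrounding each soft index. A self-contained sketch with hypothetical indices and a made-up 3-dim feature, reshaping the floor distance for broadcasting just as the `rearrange` call does:

```python
import numpy as np
import torch

soft_idx = np.array([10.25, 12.0])          # hypothetical soft row indices
floor_idx = np.floor(soft_idx).astype(int)  # [10, 12]
d2f = soft_idx - floor_idx                  # [0.25, 0.0]
ceil_idx = floor_idx + (d2f > 0.0)          # [11, 12]; stays put when exactly on a row

# Pretend these came back from _query_hf_dataset: one 3-dim vector per queried row.
q_floor = torch.tensor([[0.0, 0.0, 0.0], [4.0, 4.0, 4.0]])
q_ceil = torch.tensor([[1.0, 1.0, 1.0], [4.0, 4.0, 4.0]])

w = torch.tensor(d2f).unsqueeze(-1)         # "n -> n 1" for broadcasting
interp = (1.0 - w) * q_floor + w * q_ceil
# -> tensor([[0.25, 0.25, 0.25], [4.00, 4.00, 4.00]])
```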
+    def _query_hf_dataset(self, hard_indices: dict[str, np.ndarray]) -> dict:
+        """Query dataset using hard (integer) indices.
+
+        Args:
+            hard_indices: Dictionary mapping feature names to integer indices.
+
+        Returns:
+            Dictionary of feature values stacked as tensors.
+        """
+        # TODO(shuheng): look into optimization when using hf_dataset.select
+        return {
+            key: torch.stack(list(self.hf_dataset.select(q_idx)[key]))
+            for key, q_idx in hard_indices.items()
+            if key not in self.meta.video_keys
+        }
+
+    def _query_videos(self, query_timestamps: dict[str, np.ndarray], ep_idx: int) -> dict[str, torch.Tensor]:
+        """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
+        in the main process (e.g. by using a second DataLoader with num_workers=0). It will result in a
+        Segmentation Fault. This probably happens because a memory reference to the video loader is created
+        in the main process and a subprocess fails to access it.
+        """
+        item = {}
+        for vid_key, query_ts in query_timestamps.items():
+            video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
+            frame_indices_soft = query_ts * self.fps
+            if self.image_resample_strategy == "linear":
+                frame_indices_floor = np.floor(frame_indices_soft).astype(int)
+                dist2floor = frame_indices_soft - frame_indices_floor
+                frame_indices_ceil = frame_indices_floor + (dist2floor > 0.0)
+                query_ts_floor = (frame_indices_floor / self.fps).tolist()
+                query_ts_ceil = (frame_indices_ceil / self.fps).tolist()
+                frames_floor = decode_video_frames(
+                    video_path, query_ts_floor, self.tolerance_s, self.video_backend
+                )
+                frames_ceil = decode_video_frames(
+                    video_path, query_ts_ceil, self.tolerance_s, self.video_backend
+                )
+                dist2floor = dist2floor[:, None, None, None]
+                frames = frames_ceil * dist2floor + frames_floor * (1 - dist2floor)
+            elif self.image_resample_strategy == "nearest":
+                query_ts_rounded = (frame_indices_soft.round() / self.fps).tolist()
+                frames = decode_video_frames(
+                    video_path, query_ts_rounded, self.tolerance_s, self.video_backend
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported image_resample_strategy: {self.image_resample_strategy}. Choose 'linear' or 'nearest'."
+                )
+            item[vid_key] = frames.squeeze(0)
+
+        return item
+
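The `[:, None, None, None]` reshape above broadcasts one blend weight per requested frame across the (channel, height, width) axes. A tiny standalone illustration with made-up "frames" (torch-only, so it runs as-is):

```python
import torch

# Two requested frames of shape (channels, height, width) = (3, 2, 2), made up for illustration.
frames_floor = torch.zeros(2, 3, 2, 2)
frames_ceil = torch.ones(2, 3, 2, 2)

dist2floor = torch.tensor([0.25, 0.75])[:, None, None, None]  # one weight per frame
frames = frames_ceil * dist2floor + frames_floor * (1 - dist2floor)

print(frames[0, 0, 0, 0].item(), frames[1, 0, 0, 0].item())  # 0.25 0.75
```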
+    def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
+        """Add padding mask keys to the item dictionary.
+
+        Args:
+            item: Item dictionary to modify.
+            padding: Dictionary mapping feature names to boolean padding masks.
+
+        Returns:
+            Modified item dictionary with padding keys added.
+        """
+        for key, val in padding.items():
+            item[key] = torch.BoolTensor(val)
+        return item
+
+    def __len__(self):
+        return self.num_frames
+
+    @retry_random_on_failure
+    def __getitem__(self, idx) -> dict:
+        item = self.hf_dataset[idx]
+        ep_idx = item["episode_index"].item()
+
+        if self.episode_data_index is not None and self.epi2idx is not None:
+            ep_end = self.episode_data_index["to"][self.epi2idx[ep_idx]].item()
+
+        episodes_info = self.meta.episodes[ep_idx]
+
+        # Soft indices are floats instead of integers, which allows for different interpolation strategies
+        # such as nearest neighbor or linear interpolation.
+        query_indices_soft = None
+        if self.delta_timestamps_params[0]:
+            query_indices_soft, padding = self._get_query_indices_soft(idx, ep_idx)
+            query_result = self._query_hf_dataset_soft(query_indices_soft)
+            item = {**item, **padding}
+            for key, val in query_result.items():
+                item[key] = val
+
+        if len(self.meta.video_keys) > 0:
+            current_ts = item["timestamp"].item()
+            query_timestamps = self._get_query_timestamps(current_ts, query_indices_soft)
+            video_frames = self._query_videos(query_timestamps, ep_idx)
+            item = {**video_frames, **item}
+
+        if self.image_transforms is not None:
+            image_keys = self.meta.camera_keys
+            for cam in image_keys:
+                item[cam] = self.image_transforms(item[cam])
+
+        # Add task as a string
+        task_idx = item["task_index"].item()
+        item["task"] = self.meta.tasks[task_idx]
+
+        # If indices is an int, squeeze the feature
+        for feature, (_, indices) in self.feature2group.items():
+            if isinstance(indices, int):
+                item[feature] = item[feature].squeeze(0)
+
+        # The conversion script of the AGI BOT dataset uses a dataloader to enumerate data and compute stats.
+        # If we enabled standardization there, those stats would be computed under their mapped names, which
+        # is wrong.
+        if self.standardize:
+            # Add response as a string
+            if "response" not in item:
+                item["response"] = ""
+
+            episode_index = item["episode_index"].item()
+            # don't convert the timestamp to `float`, because torch.float64 is not supported on MPS
+            timestamp = item["timestamp"]
+
+            # change data naming to standard data format
+            item = self._to_standard_data_format(item)
+
+            if self.meta.advantages is not None:
+                advantage = self.meta.advantages.get((episode_index, timestamp), 0)
+                item["advantage"] = torch.tensor(advantage, dtype=torch.bfloat16)
+            else:
+                item["advantage"] = torch.tensor(0.0, dtype=torch.bfloat16)
+
+            success = episodes_info.get("success", True)
+
+            # only add the fields below to item when training or evaluating the value functions
+            if isinstance(self.cfg.policy, ValueConfig):
+                item["return_bin_idx"], item["return_continuous"] = calculate_return_bins_with_equal_width(
+                    success,
+                    self.cfg.policy.reward_config.number_of_bins,
+                    ep_end,
+                    self.cfg.policy.reward_config.reward_normalizer,
+                    idx,
+                    self.cfg.policy.reward_config.C_neg,
+                )
+
+                item["return_bin_idx"] = torch.tensor(item["return_bin_idx"], dtype=torch.long)
+                item["return_continuous"] = torch.tensor(item["return_continuous"], dtype=torch.float32)
+                # success, episode_end_idx and last_step are required for calculating the advantage
+                if self.return_advantage_input:
+                    item["success"] = success
+                    item["episode_end_idx"] = ep_end
+                    item["current_idx"] = idx
+                    item["last_step"] = idx + self.cfg.policy.reward_config.N_steps_look_ahead >= ep_end
+                    item["episode_index"] = episode_index
+                    item["timestamp"] = timestamp
+            else:
+                item["return_bin_idx"] = torch.tensor(0, dtype=torch.long)
+                item["return_continuous"] = torch.tensor(0, dtype=torch.float32)
+
+            # sanity check for action chunk lengths
+            assert item["actions"].shape[0] == self.cfg.action_chunk
+            assert item["action_is_pad"].shape[0] == self.cfg.action_chunk
+
+        return item
+
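Since `__getitem__` returns a plain dict of tensors and strings, the dataset drops into a standard PyTorch DataLoader. The sketch below is hypothetical: the repo id is made up, and it assumes the constructor mirrors upstream LeRobot's `LeRobotDataset(repo_id)` (the opentau signature may require additional config). Note the warning in `_query_videos` about decoding videos from the main process while workers are in use:

```python
import torch
from opentau.datasets.lerobot_dataset import LeRobotDataset

# Hypothetical repo id, for illustration only.
dataset = LeRobotDataset("user/my-lerobot-dataset")

# Video decoding happens inside __getitem__, so keep it in the workers
# (see the note on _query_videos about segfaults in the main process).
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4, shuffle=True)

batch = next(iter(loader))
print(batch["task"][:2])  # tasks come back as strings, collated into a list
```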
+    def _get_feature_mapping_key(self) -> str:
+        return self.repo_id
+
+    def __repr__(self):
+        feature_keys = list(self.features)
+        return (
+            f"{self.__class__.__name__}({{\n"
+            f"    Repository ID: '{self.repo_id}',\n"
+            f"    Number of selected episodes: '{self.num_episodes}',\n"
+            f"    Number of selected samples: '{self.num_frames}',\n"
+            f"    Features: '{feature_keys}',\n"
+            "})"
+        )
+
+    def create_episode_buffer(self, episode_index: int | None = None) -> dict:
+        current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
+        ep_buffer = {}
+        # size and task are special cases that are not in self.features
+        ep_buffer["size"] = 0
+        ep_buffer["task"] = []
+        for key in self.features:
+            ep_buffer[key] = current_ep_idx if key == "episode_index" else []
+        return ep_buffer
+
+    def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
+        fpath = DEFAULT_IMAGE_PATH.format(
+            image_key=image_key, episode_index=episode_index, frame_index=frame_index
+        )
+        return self.root / fpath
+
+    def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
+        if self.image_writer is None:
+            if isinstance(image, torch.Tensor):
+                image = image.cpu().numpy()
+            write_image(image, fpath)
+        else:
+            self.image_writer.save_image(image=image, fpath=fpath)
+
+    def add_frame(self, frame: dict) -> None:
+        """
+        This function only adds the frame to the episode_buffer. Apart from images, which are written to a
+        temporary directory, nothing is written to disk. To save those frames, the 'save_episode()' method
+        then needs to be called.
+        """
+        # Convert torch to numpy if needed
+        for name in frame:
+            if isinstance(frame[name], torch.Tensor):
+                frame[name] = frame[name].numpy()
+
+        validate_frame(frame, self.features)
+
+        if self.episode_buffer is None:
+            self.episode_buffer = self.create_episode_buffer()
+
+        # Automatically add frame_index and timestamp to episode buffer
+        frame_index = self.episode_buffer["size"]
+        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
+        self.episode_buffer["frame_index"].append(frame_index)
+        self.episode_buffer["timestamp"].append(timestamp)
+
+        # Add frame features to episode_buffer
+        for key in frame:
+            if key == "task":
+                # Note: we associate the task in natural language to its task index during `save_episode`
+                self.episode_buffer["task"].append(frame["task"])
+                continue
+
+            if key not in self.features:
+                raise ValueError(
+                    f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
+                )
+
+            if self.features[key]["dtype"] in ["image", "video"]:
+                img_path = self._get_image_file_path(
+                    episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
+                )
+                if frame_index == 0:
+                    img_path.parent.mkdir(parents=True, exist_ok=True)
+                self._save_image(frame[key], img_path)
+                self.episode_buffer[key].append(str(img_path))
+            else:
+                self.episode_buffer[key].append(frame[key])
+
+        self.episode_buffer["size"] += 1
+
+    def save_episode(self, episode_data: dict | None = None) -> None:
+        """
+        This will save to disk the current episode in self.episode_buffer.
+
+        Args:
+            episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
+                save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
+                None.
+        """
+        episode_buffer = episode_data if episode_data else self.episode_buffer
+
+        validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
+
+        # size and task are special cases that won't be added to hf_dataset
+        episode_length = episode_buffer.pop("size")
+        tasks = episode_buffer.pop("task")
+        episode_tasks = list(set(tasks))
+        episode_index = episode_buffer["episode_index"]
+
+        episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
+        episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
+
+        # Add new tasks to the tasks dictionary
+        for task in episode_tasks:
+            task_index = self.meta.get_task_index(task)
+            if task_index is None:
+                self.meta.add_task(task)
+
+        # Given tasks in natural language, find their corresponding task indices
+        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
+
+        for key, ft in self.features.items():
+            # index, episode_index, task_index are already processed above, and image and video
+            # are processed separately by storing the image path and frame info as metadata
+            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
+                continue
+            episode_buffer[key] = np.stack(episode_buffer[key])
+
+        self._wait_image_writer()
+        self._save_episode_table(episode_buffer, episode_index)
+        ep_stats = compute_episode_stats(episode_buffer, self.features)
+
+        if len(self.meta.video_keys) > 0:
+            video_paths = self.encode_episode_videos(episode_index)
+            for key in self.meta.video_keys:
+                episode_buffer[key] = video_paths[key]
+
+        # `meta.save_episode` must be executed after encoding the videos
+        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
+
+        ep_data_index, _ = get_episode_data_index(self.meta.episodes, [episode_index])
+        ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
+        check_timestamps_sync(
+            episode_buffer["timestamp"],
+            episode_buffer["episode_index"],
+            ep_data_index_np,
+            self.fps,
+            self.tolerance_s,
+        )
+
+        video_files = list(self.root.rglob("*.mp4"))
+        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
+
+        parquet_files = list(self.root.rglob("*.parquet"))
+        assert len(parquet_files) == self.num_episodes
+
+        # delete images
+        img_dir = self.root / "images"
+        if img_dir.is_dir():
+            shutil.rmtree(self.root / "images")
+
+        if not episode_data:  # Reset the buffer
+            self.episode_buffer = self.create_episode_buffer()
+
+        self.episode_data_index, self.epi2idx = get_episode_data_index(self.meta.episodes, self.episodes)
+
+    def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
+        episode_dict = {key: episode_buffer[key] for key in self.hf_features}
+        ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
+        ep_dataset = embed_images(ep_dataset)
+        self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
+        self.hf_dataset.set_transform(hf_transform_to_torch)
+        ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
+        ep_data_path.parent.mkdir(parents=True, exist_ok=True)
+        ep_dataset.to_parquet(ep_data_path)
+
+    def clear_episode_buffer(self) -> None:
+        episode_index = self.episode_buffer["episode_index"]
+        if self.image_writer is not None:
+            for cam_key in self.meta.camera_keys:
+                img_dir = self._get_image_file_path(
+                    episode_index=episode_index, image_key=cam_key, frame_index=0
+                ).parent
+                if img_dir.is_dir():
+                    shutil.rmtree(img_dir)
+
+        # Reset the buffer
+        self.episode_buffer = self.create_episode_buffer()
+
+    def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
+        if isinstance(self.image_writer, AsyncImageWriter):
+            logging.warning(
+                "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
+            )
+
+        self.image_writer = AsyncImageWriter(
+            num_processes=num_processes,
+            num_threads=num_threads,
+        )
+
+    def stop_image_writer(self) -> None:
+        """
+        Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
+        remove the image_writer so that the LeRobotDataset object can be pickled and parallelized.
+        """
+        if self.image_writer is not None:
+            self.image_writer.stop()
+            self.image_writer = None
+
+    def _wait_image_writer(self) -> None:
+        """Wait for asynchronous image writer to finish."""
+        if self.image_writer is not None:
+            self.image_writer.wait_until_done()
+
+    def encode_videos(self) -> None:
+        """
+        Use ffmpeg to convert frames stored as png into mp4 videos.
+        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
+        since video encoding with ffmpeg already uses multithreading.
+        """
+        for ep_idx in range(self.meta.total_episodes):
+            self.encode_episode_videos(ep_idx)
+
+    def encode_episode_videos(self, episode_index: int) -> dict:
+        """
+        Use ffmpeg to convert frames stored as png into mp4 videos.
+        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
+        since video encoding with ffmpeg already uses multithreading.
+        """
+        video_paths = {}
+        for key in self.meta.video_keys:
+            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
+            video_paths[key] = str(video_path)
+            if video_path.is_file():
+                # Skip if video is already encoded. Could be the case when resuming data recording.
+                continue
+            img_dir = self._get_image_file_path(
+                episode_index=episode_index, image_key=key, frame_index=0
+            ).parent
+            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+
+        return video_paths
+
+    def _separate_image_in_time(self, item: dict):
+        name_map = DATA_FEATURES_NAME_MAPPING[self._get_feature_mapping_key()]
+        cam_keys = {v for k, v in name_map.items() if k.startswith("camera")}
+        for k in cam_keys:
+            images = item.pop(k)
+            assert len(images) == 2, (
+                f"{k} in {self.__class__} is expected to have length 2, got shape={images.shape}"
+            )
+            item[k + "_local"], item[k] = images
+
+            pads = item.pop(k + "_is_pad")
+            assert len(pads) == 2, (
+                f"{k} in {self.__class__} is expected to have length 2, got shape={pads.shape}"
+            )
+            item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
+
+    @staticmethod
+    def compute_delta_params(
+        mean: dict[str, np.ndarray | list[float]],
+        std: dict[str, np.ndarray | list[float]],
+        lower: dict[str, np.ndarray | list[float]],
+        upper: dict[str, np.ndarray | list[float]],
+    ):
+        r"""Process the parameters `mean`, `std`, `lower` and `upper` for delta timestamps.
+
+        Delta timestamps will be computed dynamically in `__getitem__` with `clip(dT, lower, upper)`, where
+        `dT` follows the Gaussian distribution N(mean, std^2). Each parameter is a dictionary mapping group
+        names to sequences of floats.
+
+        For example, mean = {"group1": [-0.1, 0.0, 0.1], "group2": [0.0, 0.2]} indicates that 3 delta
+        timestamps will be sampled for features in group1: times t-0.1, t, and t+0.1; and 2 delta timestamps
+        will be sampled for features in group2: times t and t+0.2, where t is the timestamp of the current
+        data point.
+
+        It is assumed that `std`, `lower`, and `upper` have the same keys as `mean`, and that matching keys
+        have values of the same length. If a key is absent from `std`, `lower`, or `upper`, it will be set to
+        a default value. Namely, `std` will be set to all 0, `lower` will be set to all `-inf`, and `upper`
+        to all `+inf`, with lengths equal to the length of the sequence in `mean` for that key.
+        If a key is absent from `mean` but present in `std`, `lower`, or `upper`, it will be ignored.
+
+        After processing, the function returns four dictionaries: `mean`, `std`, `lower`, and `upper`, where
+        each key is a feature group name and each value is a numpy array of floats, satisfying the above
+        conditions.
+        """
+        inf = float("inf")
+        mean = mean or {}
+        mean = {k: np.array(v) for k, v in mean.items()}
+
+        std = std or {}
+        std = {k: np.array(std.get(k) or np.zeros_like(v)) for k, v in mean.items()}
+
+        lower = lower or {}
+        lower = {k: np.array(lower.get(k) or (np.zeros_like(v) - inf)) for k, v in mean.items()}
+
+        upper = upper or {}
+        upper = {k: np.array(upper.get(k) or (np.zeros_like(v) + inf)) for k, v in mean.items()}
+
+        for k in mean:
+            if not (mean[k].shape == std[k].shape == lower[k].shape == upper[k].shape):
+                raise ValueError(
+                    f"Delta timestamps parameters for {k} have inconsistent shapes: "
+                    f"mean={mean[k].shape}, std={std[k].shape}, lower={lower[k].shape}, upper={upper[k].shape}"
+                )
+
+        return mean, std, lower, upper
+
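A quick sketch of the default filling described in the docstring, calling the static method directly; the group names and values here are made up:

```python
from opentau.datasets.lerobot_dataset import LeRobotDataset

# Hypothetical groups: only `mean` is fully specified.
mean, std, lower, upper = LeRobotDataset.compute_delta_params(
    mean={"action": [0.0, 0.1, 0.2], "observation.state": [0.0]},
    std={"action": [0.01, 0.01, 0.01]},  # no entry for observation.state -> defaults to zeros
    lower=None,                          # defaults to -inf everywhere
    upper=None,                          # defaults to +inf everywhere
)

print(std["observation.state"])  # [0.]
print(lower["action"])           # [-inf -inf -inf]
```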
+    def _check_feature_group_mapping(self):
+        for feature, (group, indices) in self.feature2group.items():
+            if group not in self.delta_timestamps_params[0]:
+                raise ValueError(
+                    f"Feature '{feature}' is mapped to group '{group}', which is not present in "
+                    "delta_timestamps_params. Please check the mapping."
+                )
+            if indices is not None and not isinstance(indices, (int, list)):
+                raise ValueError(
+                    f"Indices for feature '{feature}' in group '{group}' should be a list, an int, or None"
+                )
+
+    @classmethod
+    def create(
+        cls,
+        repo_id: str,
+        fps: int,
+        root: str | Path | None = None,
+        robot_type: str | None = None,
+        features: dict | None = None,
+        use_videos: bool = True,
+        tolerance_s: float = 1e-4,
+        image_writer_processes: int = 0,
+        image_writer_threads: int = 0,
+        video_backend: str | None = None,
+        image_resample_strategy: str = "nearest",
+        vector_resample_strategy: str = "nearest",
+        standardize: bool = True,
+    ) -> "LeRobotDataset":
+        """Create a LeRobotDataset from scratch in order to record data."""
+        obj = cls.__new__(cls)
+        obj.meta = LeRobotDatasetMetadata.create(
+            repo_id=repo_id,
+            fps=fps,
+            root=root,
+            robot_type=robot_type,
+            features=features,
+            use_videos=use_videos,
+        )
+        obj.repo_id = obj.meta.repo_id
+        obj.root = obj.meta.root
+        obj.revision = None
+        obj.tolerance_s = tolerance_s
+        obj.image_writer = None
+
+        if image_writer_processes or image_writer_threads:
+            obj.start_image_writer(image_writer_processes, image_writer_threads)
+
+        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
+        obj.episode_buffer = obj.create_episode_buffer()
+
+        obj.episodes = None
+        obj.hf_dataset = obj.create_hf_dataset()
+        obj.image_transforms = None
+        obj.delta_timestamps_params = obj.compute_delta_params(None, None, None, None)
+        obj.feature2group = {}
+        obj.episode_data_index = None
+        obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj.image_resample_strategy = image_resample_strategy
+        obj.vector_resample_strategy = vector_resample_strategy
+        obj.standardize = standardize
+        obj.episode_data_index, obj.epi2idx = get_episode_data_index(obj.meta.episodes, obj.episodes)
+        return obj
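To tie `create`, `add_frame`, and `save_episode` together, here is a hypothetical recording sketch. The repo id, fps, feature schema, and frame contents are made up, and the exact schema expected by `validate_frame` may differ:

```python
import numpy as np
from opentau.datasets.lerobot_dataset import LeRobotDataset

# Hypothetical schema: a 6-DoF state/action pair and one camera stream.
features = {
    "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
    "action": {"dtype": "float32", "shape": (6,), "names": None},
    "observation.images.cam0": {"dtype": "video", "shape": (480, 640, 3), "names": None},
}

dataset = LeRobotDataset.create("user/my-new-dataset", fps=30, features=features)

for _ in range(90):  # one 3-second episode at 30 fps
    dataset.add_frame({
        "observation.state": np.zeros(6, dtype=np.float32),
        "action": np.zeros(6, dtype=np.float32),
        "observation.images.cam0": np.zeros((480, 640, 3), dtype=np.uint8),
        "task": "pick up the cube",
    })

dataset.save_episode()  # writes parquet, encodes videos, updates metadata
```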