opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0

opentau/datasets/utils.py
@@ -0,0 +1,1243 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for dataset management, I/O, and validation.

This module provides a comprehensive set of utility functions for working with
LeRobot datasets, including file I/O operations, metadata management, data
validation, version compatibility checking, and HuggingFace Hub integration.

The module is organized into several functional areas:

* Dictionary manipulation: Flattening/unflattening nested dictionaries
* File I/O: JSON and JSONL reading/writing with automatic directory creation
* Metadata management: Loading and saving dataset info, statistics, episodes,
  tasks, and advantages
* Data validation: Frame and episode buffer validation with detailed error
  messages

Key Features:

* Automatic serialization: Converts tensors and arrays to JSON-compatible
  formats.
* Comprehensive validation: Validates frames and episodes.
* Path management: Standard paths for dataset structure (meta/, data/).

Constants:
    DEFAULT_CHUNK_SIZE: Maximum number of episodes per chunk (1000).
    ADVANTAGES_PATH, INFO_PATH, EPISODES_PATH, STATS_PATH: Standard paths.

Classes:
    IterableNamespace: Namespace object supporting both dictionary iteration
        and dot notation access.

Functions:
    Dictionary manipulation:
        flatten_dict: Flatten nested dictionaries with separator-based keys.
        unflatten_dict: Expand flattened keys into nested dictionaries.
        serialize_dict: Convert tensors/arrays to JSON-serializable format.

    File I/O:
        load_json, write_json: JSON file operations.
        load_jsonlines, write_jsonlines: JSONL operations.

    Data validation:
        validate_frame: Validate frame data against feature specifications.
        validate_episode_buffer: Validate episode buffer before adding.

    The remaining metadata, versioning, and Hub helpers are documented in
    their individual docstrings below.

Example:
    Load dataset metadata::

        >>> info = load_info(Path("my_dataset"))
        >>> stats = load_stats(Path("my_dataset"))
        >>> episodes = load_episodes(Path("my_dataset"))

    Validate a frame::

        >>> features = {"state": {"dtype": "float32", "shape": (7,)}}
        >>> frame = {"state": np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])}
        >>> validate_frame(frame, features)
"""

import contextlib
import importlib.resources
import json
import logging
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any

import datasets
import jsonlines
import numpy as np
import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms

from opentau.configs.types import DictLike, FeatureType, PolicyFeature
from opentau.datasets.backward_compatibility import (
    V21_MESSAGE,
    BackwardCompatibilityError,
    ForwardCompatibilityError,
)
from opentau.utils.utils import is_valid_numpy_dtype_string

DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk

ADVANTAGES_PATH = "meta/advantages.json"
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"

DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [OpenTau](https://github.com/TensorAuto/OpenTau).

## {}

"""

DEFAULT_FEATURES = {
    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
    "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
    "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
    "index": {"dtype": "int64", "shape": (1,), "names": None},
    "task_index": {"dtype": "int64", "shape": (1,), "names": None},
}


def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.

    For example::

        >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
        >>> print(flatten_dict(dct))
        {"a/b": 1, "a/c/d": 2, "e": 3}
    """
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def unflatten_dict(d: dict, sep: str = "/") -> dict:
    """Unflatten a dictionary by expanding keys with separators into nested dictionaries.

    Args:
        d: Dictionary with flattened keys (e.g., {"a/b": 1, "a/c/d": 2}).
        sep: Separator used to split keys. Defaults to "/".

    Returns:
        Nested dictionary structure (e.g., {"a": {"b": 1, "c": {"d": 2}}}).

    Example:
        >>> dct = {"a/b": 1, "a/c/d": 2, "e": 3}
        >>> print(unflatten_dict(dct))
        {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
    """
    outdict = {}
    for key, value in d.items():
        parts = key.split(sep)
        d = outdict
        for part in parts[:-1]:
            if part not in d:
                d[part] = {}
            d = d[part]
        d[parts[-1]] = value
    return outdict


def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
    """Get a nested item from a dictionary-like object using a flattened key.

    Args:
        obj: Dictionary-like object to access.
        flattened_key: Flattened key path (e.g., "a/b/c").
        sep: Separator used in the flattened key. Defaults to "/".

    Returns:
        The value at the nested path specified by the flattened key.

    Example:
        >>> dct = {"a": {"b": {"c": 42}}}
        >>> get_nested_item(dct, "a/b/c")
        42
    """
    split_keys = flattened_key.split(sep)
    getter = obj[split_keys[0]]
    if len(split_keys) == 1:
        return getter

    for key in split_keys[1:]:
        getter = getter[key]

    return getter


def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
    """Serialize a dictionary containing tensors and arrays to JSON-serializable format.

    Converts torch.Tensor and np.ndarray to lists, and np.generic to Python scalars.
    The dictionary structure is preserved through flattening and unflattening.

    Args:
        stats: Dictionary containing statistics with tensor/array values.

    Returns:
        Dictionary with serialized (list/scalar) values in the same structure.

    Raises:
        NotImplementedError: If a value type is not supported for serialization.
    """
    serialized_dict = {}
    for key, value in flatten_dict(stats).items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            serialized_dict[key] = value.tolist()
        elif isinstance(value, np.generic):
            serialized_dict[key] = value.item()
        elif isinstance(value, (int, float)):
            serialized_dict[key] = value
        else:
            raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
    return unflatten_dict(serialized_dict)
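
Taken together, `flatten_dict`, `unflatten_dict`, and `serialize_dict` are what let arbitrarily nested statistics survive a JSON round trip. A minimal usage sketch, assuming the file above is importable as `opentau.datasets.utils` (which the package layout in the listing implies):

```python
import json

import numpy as np
import torch

from opentau.datasets.utils import flatten_dict, serialize_dict, unflatten_dict

# Nested stats with mixed tensor/array leaves, similar to what compute_stats produces.
stats = {"state": {"mean": torch.ones(2), "std": np.ones(2)}}

flat = flatten_dict(stats)
print(list(flat))  # ['state/mean', 'state/std']

# serialize_dict flattens, converts every leaf to lists/scalars, then unflattens,
# so the result can be dumped to JSON text and restored without loss of structure.
serializable = serialize_dict(stats)
restored = unflatten_dict(flatten_dict(json.loads(json.dumps(serializable))))
print(restored["state"]["std"])  # [1.0, 1.0]
```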


def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
    """Embed image bytes into the dataset table before saving to parquet.

    Converts the dataset to arrow format, embeds image storage, and restores
    the original format.

    Args:
        dataset: HuggingFace dataset containing images.

    Returns:
        Dataset with embedded image bytes, ready for parquet serialization.
    """
    format = dataset.format
    dataset = dataset.with_format("arrow")
    dataset = dataset.map(embed_table_storage, batched=False)
    dataset = dataset.with_format(**format)
    return dataset


def load_json(fpath: Path) -> Any:
    """Load JSON data from a file.

    Args:
        fpath: Path to the JSON file.

    Returns:
        Parsed JSON data (dict, list, or primitive type).
    """
    with open(fpath) as f:
        return json.load(f)


def write_json(data: dict, fpath: Path) -> None:
    """Write data to a JSON file.

    Creates parent directories if they don't exist. Uses 4-space indentation
    and allows non-ASCII characters.

    Args:
        data: Dictionary or other JSON-serializable data to write.
        fpath: Path where the JSON file will be written.
    """
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with open(fpath, "w") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def load_jsonlines(fpath: Path) -> list[Any]:
    """Load JSON Lines (JSONL) data from a file.

    Args:
        fpath: Path to the JSONL file.

    Returns:
        List of dictionaries, one per line in the file.
    """
    with jsonlines.open(fpath, "r") as reader:
        return list(reader)


def write_jsonlines(data: dict, fpath: Path) -> None:
    """Write data to a JSON Lines (JSONL) file.

    Creates parent directories if they don't exist. Writes each item in the
    data iterable as a separate line.

    Args:
        data: Iterable of dictionaries to write (one per line).
        fpath: Path where the JSONL file will be written.
    """
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "w") as writer:
        writer.write_all(data)


def append_jsonlines(data: dict, fpath: Path) -> None:
    """Append a single dictionary to a JSON Lines (JSONL) file.

    Creates parent directories if they don't exist. Appends the data as a
    new line to the existing file.

    Args:
        data: Dictionary to append as a new line.
        fpath: Path to the JSONL file (will be created if it doesn't exist).
    """
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with jsonlines.open(fpath, "a") as writer:
        writer.write(data)


def write_info(info: dict, local_dir: Path) -> None:
    """Write dataset info dictionary to the standard info.json file.

    Args:
        info: Dataset info dictionary to write.
        local_dir: Root directory of the dataset where meta/info.json will be written.
    """
    write_json(info, local_dir / INFO_PATH)


def load_info(local_dir: Path) -> dict:
    """Load dataset info from the standard info.json file.

    Converts feature shapes from lists to tuples for consistency.

    Args:
        local_dir: Root directory of the dataset containing meta/info.json.

    Returns:
        Dataset info dictionary with feature shapes as tuples.
    """
    info = load_json(local_dir / INFO_PATH)
    for ft in info["features"].values():
        ft["shape"] = tuple(ft["shape"])
    return info


def write_stats(stats: dict, local_dir: Path) -> None:
    """Write dataset statistics to the standard stats.json file.

    Serializes tensors and arrays to JSON-compatible format before writing.

    Args:
        stats: Dictionary containing dataset statistics (may contain tensors/arrays).
        local_dir: Root directory of the dataset where meta/stats.json will be written.
    """
    serialized_stats = serialize_dict(stats)
    write_json(serialized_stats, local_dir / STATS_PATH)


def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
    """Convert statistics dictionary values to numpy arrays.

    Flattens the dictionary, converts all values to numpy arrays, then
    unflattens to restore the original structure.

    Args:
        stats: Dictionary with statistics (values may be lists or other types).

    Returns:
        Dictionary with the same structure but all values as numpy arrays.
    """
    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)


def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
    """Load dataset statistics from the standard stats.json file.

    Args:
        local_dir: Root directory of the dataset containing meta/stats.json.

    Returns:
        Dictionary with statistics as numpy arrays, or None if the file doesn't exist.
    """
    if not (local_dir / STATS_PATH).exists():
        return None
    stats = load_json(local_dir / STATS_PATH)
    return cast_stats_to_numpy(stats)


def load_advantages(local_dir: Path) -> dict:
    """Load advantage values from the advantages.json file.

    Advantages are keyed by (episode_index, timestamp) tuples in the JSON file
    as comma-separated strings, which are converted to tuple keys.

    Args:
        local_dir: Root directory of the dataset containing meta/advantages.json.

    Returns:
        Dictionary mapping (episode_index, timestamp) tuples to advantage values,
        or None if the file doesn't exist.
    """
    if not (local_dir / ADVANTAGES_PATH).exists():
        return None
    advantages = load_json(local_dir / ADVANTAGES_PATH)
    return {(int(k.split(",")[0]), float(k.split(",")[1])): v for k, v in advantages.items()}
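
A sketch of how the info/stats/advantages helpers above compose on disk. The local root `Path("my_dataset")` and the advantage values are illustrative; the advantage keys follow the `"episode_index,timestamp"` string format that `load_advantages` parses:

```python
from pathlib import Path

import numpy as np

from opentau.datasets.utils import (
    ADVANTAGES_PATH,
    load_advantages,
    load_info,
    load_stats,
    write_info,
    write_json,
    write_stats,
)

root = Path("my_dataset")

write_info({"features": {"state": {"dtype": "float32", "shape": [7], "names": None}}}, root)
write_stats({"state": {"mean": np.zeros(7), "std": np.ones(7)}}, root)
write_json({"0,0.0": 1.25, "0,0.033": -0.5}, root / ADVANTAGES_PATH)  # keyed "ep_idx,timestamp"

info = load_info(root)       # shapes come back as tuples: (7,)
stats = load_stats(root)     # values come back as np.ndarray
adv = load_advantages(root)  # keys come back as (0, 0.0), (0, 0.033)
print(info["features"]["state"]["shape"], stats["state"]["mean"].shape, adv[(0, 0.0)])
```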


def write_task(task_index: int, task: dict, local_dir: Path) -> None:
    """Write a task entry to the tasks.jsonl file.

    Args:
        task_index: Integer index of the task.
        task: Task description dictionary.
        local_dir: Root directory of the dataset where meta/tasks.jsonl will be written.
    """
    task_dict = {
        "task_index": task_index,
        "task": task,
    }
    append_jsonlines(task_dict, local_dir / TASKS_PATH)


def load_tasks(local_dir: Path) -> tuple[dict, dict]:
    """Load tasks from the tasks.jsonl file.

    Args:
        local_dir: Root directory of the dataset containing meta/tasks.jsonl.

    Returns:
        Tuple of (tasks_dict, task_to_index_dict):
            - tasks_dict: Dictionary mapping task_index to task description.
            - task_to_index_dict: Dictionary mapping task description to task_index.
    """
    tasks = load_jsonlines(local_dir / TASKS_PATH)
    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
    task_to_task_index = {task: task_index for task_index, task in tasks.items()}
    return tasks, task_to_task_index


def write_episode(episode: dict, local_dir: Path) -> None:
    """Write an episode entry to the episodes.jsonl file.

    Args:
        episode: Episode dictionary containing episode_index, tasks, length, etc.
        local_dir: Root directory of the dataset where meta/episodes.jsonl will be written.
    """
    append_jsonlines(episode, local_dir / EPISODES_PATH)


def load_episodes(local_dir: Path) -> dict:
    """Load episodes from the episodes.jsonl file.

    Args:
        local_dir: Root directory of the dataset containing meta/episodes.jsonl.

    Returns:
        Dictionary mapping episode_index to episode information dictionary.
    """
    episodes = load_jsonlines(local_dir / EPISODES_PATH)
    return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}


def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path) -> None:
    """Write episode statistics to the episodes_stats.jsonl file.

    Serializes tensors and arrays in the stats before writing.

    Args:
        episode_index: Index of the episode.
        episode_stats: Dictionary containing statistics for the episode (may contain tensors/arrays).
        local_dir: Root directory of the dataset where meta/episodes_stats.jsonl will be written.
    """
    # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
    # is a dictionary of stats and not an integer.
    episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
    append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)


def load_episodes_stats(local_dir: Path) -> dict:
    """Load episode statistics from the episodes_stats.jsonl file.

    Args:
        local_dir: Root directory of the dataset containing meta/episodes_stats.jsonl.

    Returns:
        Dictionary mapping episode_index to statistics dictionary (with numpy arrays).
    """
    episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
    return {
        item["episode_index"]: cast_stats_to_numpy(item["stats"])
        for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
    }


def backward_compatible_episodes_stats(
    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
    """Create episode-level statistics from global statistics for backward compatibility.

    In older dataset versions, statistics were stored globally rather than per-episode.
    This function creates per-episode statistics by assigning the same global stats
    to each episode.

    Args:
        stats: Global statistics dictionary.
        episodes: List of episode indices.

    Returns:
        Dictionary mapping episode_index to the same statistics dictionary.
    """
    return dict.fromkeys(episodes, stats)
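
The per-episode JSONL writers above are append-only, so recording one episode's bookkeeping and reading it back looks roughly like the following sketch. The root path and task string are invented, and the task is passed as a plain string, which is what `load_tasks` assumes when it inverts the mapping:

```python
from pathlib import Path

import numpy as np

from opentau.datasets.utils import (
    load_episodes,
    load_episodes_stats,
    load_tasks,
    write_episode,
    write_episode_stats,
    write_task,
)

root = Path("my_dataset")

write_task(0, "pick up the red block", root)
write_episode({"episode_index": 0, "tasks": ["pick up the red block"], "length": 150}, root)
write_episode_stats(0, {"state": {"mean": np.zeros(7), "std": np.ones(7)}}, root)

tasks, task_to_index = load_tasks(root)    # ({0: "pick up ..."}, {"pick up ...": 0})
episodes = load_episodes(root)             # {0: {"episode_index": 0, "length": 150, ...}}
episode_stats = load_episodes_stats(root)  # {0: {"state": {"mean": array([...]), ...}}}
print(task_to_index["pick up the red block"], episodes[0]["length"])
```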


def load_image_as_numpy(
    fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
    """Load an image file as a numpy array.

    Args:
        fpath: Path to the image file.
        dtype: Data type for the array. Defaults to np.float32.
        channel_first: If True, return array in (C, H, W) format; otherwise (H, W, C).
            Defaults to True.

    Returns:
        Image as numpy array. If dtype is floating point, values are normalized to [0, 1].
        Otherwise, values are in [0, 255].
    """
    img = PILImage.open(fpath).convert("RGB")
    img_array = np.array(img, dtype=dtype)
    if channel_first:  # (H, W, C) -> (C, H, W)
        img_array = np.transpose(img_array, (2, 0, 1))
    if np.issubdtype(dtype, np.floating):
        img_array /= 255.0
    return img_array


def hf_transform_to_torch(items_dict: dict[str, list]):
    """Transform items from a Hugging Face dataset (pyarrow) into torch tensors.

    Importantly, images are converted from PIL, which corresponds to a channel-last
    representation (h w c) of uint8 type, to a torch image representation with
    channel first (c h w) of float32 type in range [0, 1].
    """
    for key in items_dict:
        first_item = items_dict[key][0]
        if isinstance(first_item, PILImage.Image):
            to_tensor = transforms.ToTensor()
            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
        elif first_item is None:
            pass
        else:
            items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
    return items_dict
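
`hf_transform_to_torch` is shaped like a `datasets` batch transform, so it is typically attached with `set_transform`. A sketch under that assumption; the toy table below is illustrative and relies on `datasets` inferring an Image feature for the PIL column:

```python
import datasets
import numpy as np
from PIL import Image

from opentau.datasets.utils import hf_transform_to_torch

ds = datasets.Dataset.from_dict(
    {
        "observation.image": [Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))],
        "state": [[0.0] * 7],
        "task": ["pick up the red block"],
    }
)
ds.set_transform(hf_transform_to_torch)  # applied lazily on __getitem__

item = ds[0]
print(item["observation.image"].shape)  # torch.Size([3, 64, 64]), float32 in [0, 1]
print(item["state"].shape)              # torch.Size([7])
print(item["task"])                     # strings are left untouched
```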


def is_valid_version(version: str) -> bool:
    """Check if a version string is valid and can be parsed.

    Args:
        version: Version string to validate.

    Returns:
        True if the version string is valid, False otherwise.
    """
    try:
        packaging.version.parse(version)
        return True
    except packaging.version.InvalidVersion:
        return False


def check_version_compatibility(
    repo_id: str,
    version_to_check: str | packaging.version.Version,
    current_version: str | packaging.version.Version,
    enforce_breaking_major: bool = True,
) -> None:
    """Check compatibility between a dataset version and the current codebase version.

    Args:
        repo_id: Repository ID of the dataset.
        version_to_check: Version of the dataset to check.
        current_version: Current codebase version.
        enforce_breaking_major: If True, raise error for major version mismatches.
            Defaults to True.

    Raises:
        BackwardCompatibilityError: If the dataset version is too old (major version mismatch).
    """
    v_check = (
        packaging.version.parse(version_to_check)
        if not isinstance(version_to_check, packaging.version.Version)
        else version_to_check
    )
    v_current = (
        packaging.version.parse(current_version)
        if not isinstance(current_version, packaging.version.Version)
        else current_version
    )
    if v_check.major < v_current.major and enforce_breaking_major:
        raise BackwardCompatibilityError(repo_id, v_check)
    elif v_check.minor < v_current.minor:
        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))


def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
    """Returns available valid versions (branches and tags) on given repo."""
    api = HfApi()
    repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
    repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
    repo_versions = []
    for ref in repo_refs:
        with contextlib.suppress(packaging.version.InvalidVersion):
            repo_versions.append(packaging.version.parse(ref))

    return repo_versions


def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
    """
    Returns the version if available on the repo, or the latest compatible one.
    Otherwise, raises `BackwardCompatibilityError`/`ForwardCompatibilityError`,
    or `RevisionNotFoundError` if the repo carries no version tag at all.
    """
    target_version = (
        packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
    )
    hub_versions = get_repo_versions(repo_id)

    if not hub_versions:
        raise RevisionNotFoundError(
            f"""Your dataset must be tagged with a codebase version.
            Assuming _version_ is the codebase_version value in the info.json, you can run this:
            ```python
            from huggingface_hub import HfApi

            hub_api = HfApi()
            hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
            ```
            """
        )

    if target_version in hub_versions:
        return f"v{target_version}"

    compatibles = [
        v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
    ]
    if compatibles:
        return_version = max(compatibles)
        if return_version < target_version:
            logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
        return f"v{return_version}"

    lower_major = [v for v in hub_versions if v.major < target_version.major]
    if lower_major:
        raise BackwardCompatibilityError(repo_id, max(lower_major))

    upper_versions = [v for v in hub_versions if v > target_version]
    assert len(upper_versions) > 0
    raise ForwardCompatibilityError(repo_id, min(upper_versions))
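
The two version helpers above are used together when opening a Hub dataset: resolve a usable revision first, then compare the dataset's declared version with the codebase's. A hypothetical flow; the repo id and the `CODEBASE_VERSION` pin are made up, and the calls require Hub access:

```python
from opentau.datasets.utils import check_version_compatibility, get_safe_version

CODEBASE_VERSION = "v2.1"  # hypothetical pin; the real constant lives elsewhere in the package
repo_id = "some-user/some-dataset"  # placeholder repo id

# Resolve to an existing tag/branch: e.g. "v2.1", or "v2.0" with a warning if only
# an older minor release of the same major version exists on the Hub.
revision = get_safe_version(repo_id, CODEBASE_VERSION)

# Then the dataset's codebase_version (normally read from info.json) is checked
# against the current one; an older major version raises BackwardCompatibilityError.
check_version_compatibility(repo_id, version_to_check=revision, current_version=CODEBASE_VERSION)
```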


def get_hf_features_from_features(features: dict) -> datasets.Features:
    """Convert dataset features dictionary to HuggingFace Features object.

    Maps feature types and shapes to appropriate HuggingFace feature types
    (Image, Value, Sequence, Array2D, Array3D, Array4D, Array5D).

    Args:
        features: Dictionary mapping feature names to feature specifications
            with 'dtype' and 'shape' keys.

    Returns:
        HuggingFace Features object compatible with the dataset library.

    Raises:
        ValueError: If a feature shape is not supported (more than 5 dimensions).
    """
    hf_features = {}
    for key, ft in features.items():
        if ft["dtype"] == "video":
            continue
        elif ft["dtype"] == "image":
            hf_features[key] = datasets.Image()
        elif ft["shape"] == (1,):
            hf_features[key] = datasets.Value(dtype=ft["dtype"])
        elif len(ft["shape"]) == 1:
            hf_features[key] = datasets.Sequence(
                length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
            )
        elif len(ft["shape"]) == 2:
            hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
        elif len(ft["shape"]) == 3:
            hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
        elif len(ft["shape"]) == 4:
            hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
        elif len(ft["shape"]) == 5:
            hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
        else:
            raise ValueError(f"Corresponding feature is not valid: {ft}")

    return datasets.Features(hf_features)


def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Convert dataset features to policy feature format.

    Maps dataset features to policy feature types (VISUAL, ENV, STATE, ACTION)
    based on feature names and data types.

    Args:
        features: Dictionary mapping feature names to feature specifications.

    Returns:
        Dictionary mapping feature names to PolicyFeature objects.

    Raises:
        ValueError: If a visual feature doesn't have 3 dimensions.
    """
    # TODO(aliberts): Implement "type" in dataset features and simplify this
    policy_features = {}
    for key, ft in features.items():
        shape = ft["shape"]
        if ft["dtype"] in ["image", "video"]:
            type = FeatureType.VISUAL
            if len(shape) != 3:
                raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
        elif key == "observation.environment_state":
            type = FeatureType.ENV
        elif key == "state":
            type = FeatureType.STATE
        elif key == "actions":
            type = FeatureType.ACTION
        else:
            continue

        policy_features[key] = PolicyFeature(
            type=type,
            shape=shape,
        )

    return policy_features
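
A sketch of the same feature specification flowing through both converters above; the camera/state/actions key names mirror the conventions this file already assumes:

```python
from opentau.datasets.utils import dataset_to_policy_features, get_hf_features_from_features

features = {
    "observation.images.top": {"dtype": "video", "shape": (3, 224, 224), "names": ["c", "h", "w"]},
    "state": {"dtype": "float32", "shape": (7,), "names": None},
    "actions": {"dtype": "float32", "shape": (7,), "names": None},
    "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
}

hf_features = get_hf_features_from_features(features)
print(list(hf_features))  # video columns are skipped: ['state', 'actions', 'episode_index']

policy_features = dataset_to_policy_features(features)
print({k: v.type for k, v in policy_features.items()})
# e.g. {'observation.images.top': FeatureType.VISUAL, 'state': FeatureType.STATE,
#       'actions': FeatureType.ACTION} -- episode_index is dropped
```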


def create_empty_dataset_info(
    codebase_version: str,
    fps: int,
    robot_type: str,
    features: dict,
    use_videos: bool,
) -> dict:
    """Create an empty dataset info dictionary with default values.

    Args:
        codebase_version: Version of the codebase used to create the dataset.
        fps: Frames per second used during data collection.
        robot_type: Type of robot used (can be None).
        features: Dictionary of feature specifications.
        use_videos: Whether videos are used for visual modalities.

    Returns:
        Dictionary containing dataset metadata with initialized counters and paths.
    """
    return {
        "codebase_version": codebase_version,
        "robot_type": robot_type,
        "total_episodes": 0,
        "total_frames": 0,
        "total_tasks": 0,
        "total_videos": 0,
        "total_chunks": 0,
        "chunks_size": DEFAULT_CHUNK_SIZE,
        "fps": fps,
        "splits": {},
        "data_path": DEFAULT_PARQUET_PATH,
        "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
        "features": features,
    }


def get_episode_data_index(
    episode_dicts: dict[dict], episodes: list[int] | None = None
) -> tuple[dict[str, torch.Tensor], dict[int, int]]:
    """Compute data indices for episodes in a flattened dataset.

    Calculates start and end indices for each episode in a concatenated dataset,
    and creates a mapping from episode index to position in the episodes list.

    Args:
        episode_dicts: Dictionary mapping episode_index to episode info dicts
            containing 'length' keys.
        episodes: Optional list of episode indices to include. If None, uses all
            episodes from episode_dicts.

    Returns:
        Tuple of (episode_data_index, ep2idx):
            - episode_data_index: Dictionary with 'from' and 'to' tensors indicating
              start and end indices for each episode.
            - ep2idx: Dictionary mapping episode_index to position in the episodes list.
    """
    # `episode_dicts` are not necessarily sorted, nor do they necessarily start at episode_index 0.
    episode_lengths = {edict["episode_index"]: edict["length"] for edict in episode_dicts.values()}

    if episodes is None:
        episodes = list(episode_lengths.keys())

    episode_lengths = [episode_lengths[ep_idx] for ep_idx in episodes]
    cumulative_lengths = list(accumulate(episode_lengths))
    start = [0] + cumulative_lengths[:-1]
    end = cumulative_lengths
    ep2idx = {ep_idx: i for i, ep_idx in enumerate(episodes)}
    return {"from": torch.LongTensor(start), "to": torch.LongTensor(end)}, ep2idx
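
A worked example of the index bookkeeping `get_episode_data_index` performs, with two toy episodes of lengths 100 and 50:

```python
from opentau.datasets.utils import get_episode_data_index

episode_dicts = {
    3: {"episode_index": 3, "length": 100},
    7: {"episode_index": 7, "length": 50},
}

# All episodes: frames 0-99 belong to episode 3, frames 100-149 to episode 7.
index, ep2idx = get_episode_data_index(episode_dicts)
print(index["from"].tolist(), index["to"].tolist(), ep2idx)  # [0, 100] [100, 150] {3: 0, 7: 1}

# Restricting to episode 7 re-bases the indices to the filtered dataset.
index, ep2idx = get_episode_data_index(episode_dicts, episodes=[7])
print(index["from"].tolist(), index["to"].tolist(), ep2idx)  # [0] [50] {7: 0}
```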


def check_timestamps_sync(
    timestamps: np.ndarray,
    episode_indices: np.ndarray,
    episode_data_index: dict[str, np.ndarray],
    fps: int,
    tolerance_s: float,
    raise_value_error: bool = True,
) -> bool:
    """
    This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
    to account for possible numerical error.

    Args:
        timestamps (np.ndarray): Array of timestamps in seconds.
        episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
        episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
            which identifies indices for the end of each episode.
        fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
        tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
        raise_value_error (bool): Whether to raise a ValueError if the check fails.

    Returns:
        bool: True if all checked timestamp differences lie within tolerance, False otherwise.

    Raises:
        ValueError: If the check fails and `raise_value_error` is True.
    """
    if timestamps.shape != episode_indices.shape:
        raise ValueError(
            "timestamps and episode_indices should have the same shape. "
            f"Found {timestamps.shape=} and {episode_indices.shape=}."
        )

    # Consecutive differences
    diffs = np.diff(timestamps)
    within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s

    # Mask to ignore differences at the boundaries between episodes
    mask = np.ones(len(diffs), dtype=bool)
    ignored_diffs = episode_data_index["to"][:-1] - 1  # indices at the end of each episode
    mask[ignored_diffs] = False
    filtered_within_tolerance = within_tolerance[mask]

    # Check if all remaining diffs are within tolerance
    if not np.all(filtered_within_tolerance):
        # Track original indices before masking
        original_indices = np.arange(len(diffs))
        filtered_indices = original_indices[mask]
        outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
        outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]

        outside_tolerances = []
        for idx in outside_tolerance_indices:
            entry = {
                "timestamps": [timestamps[idx], timestamps[idx + 1]],
                "diff": diffs[idx],
                "episode_index": episode_indices[idx].item()
                if hasattr(episode_indices[idx], "item")
                else episode_indices[idx],
            }
            outside_tolerances.append(entry)

        if raise_value_error:
            raise ValueError(
                f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
                This might be due to synchronization issues during data collection.
                \n{pformat(outside_tolerances)}"""
            )
        return False

    return True
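
To make the tolerance logic above concrete, here is a toy check at 30 fps with two 4-frame episodes (all values invented): the negative jump at the episode boundary is masked out, so only within-episode gaps are tested.

```python
import numpy as np

from opentau.datasets.utils import check_timestamps_sync

fps = 30
# Two episodes of 4 frames each; timestamps restart at 0.0 for the second episode.
timestamps = np.array([i / fps for i in range(4)] * 2)
episode_indices = np.array([0] * 4 + [1] * 4)
episode_data_index = {"to": np.array([4, 8])}

ok = check_timestamps_sync(
    timestamps, episode_indices, episode_data_index, fps=fps, tolerance_s=1e-4
)
print(ok)  # True: the -0.1 s gap at the episode boundary is ignored by the mask
```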


DeltaTimestampParam = dict[str, np.ndarray]
DeltaTimestampInfo = tuple[DeltaTimestampParam, DeltaTimestampParam, DeltaTimestampParam, DeltaTimestampParam]


def get_delta_indices_soft(delta_timestamps_info: DeltaTimestampInfo, fps: int) -> DeltaTimestampParam:
    r"""Returns soft indices (not necessarily integer) for delta timestamps based on the provided information.

    Soft indices are computed by sampling from a normal distribution defined by the mean and standard deviation
    and clipping the values to the specified lower and upper bounds.
    Note: Soft indices can be converted to integer indices by either rounding or interpolation.
    """
    soft_indices = {}
    mean, std, lower, upper = delta_timestamps_info
    for key in mean:
        dT = np.random.normal(mean[key], std[key]).clip(lower[key], upper[key])  # noqa: N806
        soft_indices[key] = dT * fps

    return soft_indices


def cycle(iterable):
    """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.

    See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
    """
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
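
A sketch of the soft-index sampling described above, for a single "action" stream whose delta is centered 0.5 s in the future with 0.1 s jitter (all numbers invented for illustration):

```python
import numpy as np

from opentau.datasets.utils import get_delta_indices_soft

fps = 30
delta_timestamps_info = (
    {"action": np.array([0.5])},  # mean offset, in seconds
    {"action": np.array([0.1])},  # standard deviation
    {"action": np.array([0.2])},  # lower clip bound
    {"action": np.array([0.8])},  # upper clip bound
)

soft = get_delta_indices_soft(delta_timestamps_info, fps)
print(soft["action"])  # a fractional frame offset, e.g. something near 0.5 s * 30 fps = 15.0
```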


def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
    """Create a branch on an existing Hugging Face repo. Delete the branch if it already
    exists before creating it.
    """
    api = HfApi()

    branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
    refs = [branch.ref for branch in branches]
    ref = f"refs/heads/{branch}"
    if ref in refs:
        api.delete_branch(repo_id, repo_type=repo_type, branch=branch)

    api.create_branch(repo_id, repo_type=repo_type, branch=branch)


def create_lerobot_dataset_card(
    tags: list | None = None,
    dataset_info: dict | None = None,
    **kwargs,
) -> DatasetCard:
    """
    Keyword arguments will be used to replace values in `src/opentau/datasets/card_template.md`.
    Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
    """
    card_tags = ["OpenTau"]

    if tags:
        card_tags += tags
    if dataset_info:
        dataset_structure = "[meta/info.json](meta/info.json):\n"
        dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
        kwargs = {**kwargs, "dataset_structure": dataset_structure}
    card_data = DatasetCardData(
        license=kwargs.get("license"),
        tags=card_tags,
        task_categories=["robotics"],
        configs=[
            {
                "config_name": "default",
                "data_files": "data/*/*.parquet",
            }
        ],
    )

    card_template = (importlib.resources.files("opentau.datasets") / "card_template.md").read_text()

    return DatasetCard.from_template(
        card_data=card_data,
        template_str=card_template,
        **kwargs,
    )


class IterableNamespace(SimpleNamespace):
    """
    A namespace object that supports both dictionary-like iteration and dot notation access.
    Automatically converts nested dictionaries into IterableNamespaces.

    This class extends SimpleNamespace to provide:
    - Dictionary-style iteration over keys
    - Access to items via both dot notation (obj.key) and brackets (obj["key"])
    - Dictionary-like methods: items(), keys(), values()
    - Recursive conversion of nested dictionaries

    Args:
        dictionary: Optional dictionary to initialize the namespace
        **kwargs: Additional keyword arguments passed to SimpleNamespace

    Examples:
        >>> data = {"name": "Alice", "details": {"age": 25}}
        >>> ns = IterableNamespace(data)
        >>> ns.name
        'Alice'
        >>> ns.details.age
        25
        >>> list(ns.keys())
        ['name', 'details']
        >>> for key, value in ns.items():
        ...     print(f"{key}: {value}")
        name: Alice
        details: IterableNamespace(age=25)
    """

    def __init__(self, dictionary: dict[str, Any] | None = None, **kwargs):
        super().__init__(**kwargs)
        if dictionary is not None:
            for key, value in dictionary.items():
                if isinstance(value, dict):
                    setattr(self, key, IterableNamespace(value))
                else:
                    setattr(self, key, value)

    def __iter__(self) -> Iterator[str]:
        return iter(vars(self))

    def __getitem__(self, key: str) -> Any:
        return vars(self)[key]

    def items(self):
        return vars(self).items()

    def values(self):
        return vars(self).values()

    def keys(self):
        return vars(self).keys()
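
A short sketch of generating and publishing a dataset card with the helper above; the repo id and metadata are placeholders, and the call assumes the `card_template.md` packaged with `opentau.datasets` plus a logged-in Hub token:

```python
from opentau.datasets.utils import create_lerobot_dataset_card

card = create_lerobot_dataset_card(
    tags=["libero", "pi0"],
    dataset_info={"codebase_version": "v2.1", "fps": 30},
    license="apache-2.0",
)
# Renders the template with the OpenTau tag, robotics task category, and the
# serialized info.json snippet, then uploads README.md to the dataset repo.
card.push_to_hub("some-user/some-dataset", repo_type="dataset")
```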


def validate_frame(frame: dict, features: dict) -> None:
    """Validate that a frame dictionary matches the expected features.

    Checks that all required features are present, no unexpected features exist,
    and that feature types and shapes match the specification.

    Args:
        frame: Dictionary containing frame data to validate.
        features: Dictionary of expected feature specifications.

    Raises:
        ValueError: If the frame doesn't match the feature specifications.
    """
    optional_features = {"timestamp"}
    expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
    actual_features = set(frame.keys())

    error_message = validate_features_presence(actual_features, expected_features, optional_features)

    if "task" in frame:
        error_message += validate_feature_string("task", frame["task"])

    common_features = actual_features & (expected_features | optional_features)
    for name in common_features - {"task"}:
        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])

    if error_message:
        raise ValueError(error_message)


def validate_features_presence(
    actual_features: set[str], expected_features: set[str], optional_features: set[str]
) -> str:
    """Validate that required features are present and no unexpected features exist.

    Args:
        actual_features: Set of feature names actually present.
        expected_features: Set of feature names that must be present.
        optional_features: Set of feature names that may be present but aren't required.

    Returns:
        Error message string (empty if validation passes).
    """
    error_message = ""
    missing_features = expected_features - actual_features
    extra_features = actual_features - (expected_features | optional_features)

    if missing_features or extra_features:
        error_message += "Feature mismatch in `frame` dictionary:\n"
        if missing_features:
            error_message += f"Missing features: {missing_features}\n"
        if extra_features:
            error_message += f"Extra features: {extra_features}\n"

    return error_message


def validate_feature_dtype_and_shape(
    name: str, feature: dict, value: np.ndarray | PILImage.Image | str
) -> str:
    """Validate that a feature value matches its expected dtype and shape.

    Routes to the appropriate validation function based on feature type.

    Args:
        name: Name of the feature being validated.
        feature: Feature specification dictionary with 'dtype' and 'shape' keys.
        value: Actual value to validate.

    Returns:
        Error message string (empty if validation passes).

    Raises:
        NotImplementedError: If the feature dtype is not supported.
    """
    expected_dtype = feature["dtype"]
    expected_shape = feature["shape"]
    if is_valid_numpy_dtype_string(expected_dtype):
        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
    elif expected_dtype in ["image", "video"]:
        return validate_feature_image_or_video(name, expected_shape, value)
    elif expected_dtype == "string":
        return validate_feature_string(name, value)
    else:
        raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")


def validate_feature_numpy_array(
    name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
) -> str:
    """Validate that a numpy array feature matches expected dtype and shape.

    Args:
        name: Name of the feature being validated.
        expected_dtype: Expected numpy dtype as a string.
        expected_shape: Expected shape as a list of integers.
        value: Actual numpy array to validate.

    Returns:
        Error message string (empty if validation passes).
    """
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_dtype = value.dtype
        actual_shape = value.shape

        if actual_dtype != np.dtype(expected_dtype):
            error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"

        if actual_shape != expected_shape:
            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
    else:
        error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"

    return error_message


def validate_feature_image_or_video(
    name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
) -> str:
    """Validate that an image or video feature matches expected shape.

    Supports both channel-first (C, H, W) and channel-last (H, W, C) formats.

    Args:
        name: Name of the feature being validated.
        expected_shape: Expected shape as [C, H, W].
        value: Actual image/video value (PIL Image or numpy array).

    Returns:
        Error message string (empty if validation passes).

    Note:
        Pixel value range validation ([0,1] for float, [0,255] for uint8) is
        performed by the image writer threads, not here.
    """
    # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_shape = value.shape
        c, h, w = expected_shape
        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
    elif isinstance(value, PILImage.Image):
        pass
    else:
        error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"

    return error_message


def validate_feature_string(name: str, value: str) -> str:
    """Validate that a feature value is a string.

    Args:
        name: Name of the feature being validated.
        value: Actual value to validate.

    Returns:
        Error message string (empty if validation passes).
    """
    if not isinstance(value, str):
        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
    return ""


def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
    """Validate that an episode buffer is properly formatted.

    Checks that required keys exist, episode_index matches total_episodes,
    the buffer is not empty, and all features are present.

    Args:
        episode_buffer: Dictionary containing episode data to validate.
        total_episodes: Total number of episodes already in the dataset.
        features: Dictionary of expected feature specifications.

    Raises:
        ValueError: If the buffer is missing required keys, is empty, or has
            mismatched features.
        NotImplementedError: If episode_index doesn't match total_episodes.
    """
    if "size" not in episode_buffer:
        raise ValueError("size key not found in episode_buffer")

    if "task" not in episode_buffer:
        raise ValueError("task key not found in episode_buffer")

    if episode_buffer["episode_index"] != total_episodes:
        # TODO(aliberts): Add option to use existing episode_index
        raise NotImplementedError(
            "You might have manually provided the episode_buffer with an episode_index that doesn't "
            "match the total number of episodes already in the dataset. This is not supported for now."
        )

    if episode_buffer["size"] == 0:
        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")

    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
    if buffer_keys != set(features):
        raise ValueError(
            f"Features from `episode_buffer` don't match the ones in `features`.\n"
            f"In episode_buffer not in features: {buffer_keys - set(features)}\n"
            f"In features not in episode_buffer: {set(features) - buffer_keys}"
        )
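
Finally, a sketch of what the validation helpers above report when a frame is malformed; shape and key problems are accumulated into a single error message (the feature names and sizes here are invented):

```python
import numpy as np

from opentau.datasets.utils import validate_frame

features = {
    "state": {"dtype": "float32", "shape": (7,), "names": None},
    "observation.images.top": {"dtype": "image", "shape": (3, 224, 224), "names": ["c", "h", "w"]},
}

frame = {
    "task": "pick up the red block",
    "state": np.zeros(6, dtype=np.float32),                             # wrong shape on purpose
    "observation.images.top": np.zeros((224, 224, 3), dtype=np.uint8),  # channel-last is accepted
}

try:
    validate_frame(frame, features)
except ValueError as e:
    print(e)  # "The feature 'state' of shape '(6,)' does not have the expected shape '(7,)'."
```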