opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1243 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Utility functions for dataset management, I/O, and validation.
18
+
19
+ This module provides a comprehensive set of utility functions for working with
20
+ LeRobot datasets, including file I/O operations, metadata management, data
21
+ validation, version compatibility checking, and HuggingFace Hub integration.
22
+
23
+ The module is organized into several functional areas:
24
+
25
+ * Dictionary manipulation: Flattening/unflattening nested dictionaries
26
+ * File I/O: JSON and JSONL reading/writing with automatic directory creation
27
+ * Metadata management: Loading and saving dataset info, statistics, episodes,
28
+ tasks, and advantages
29
+ * Data validation: Frame and episode buffer validation with detailed error
30
+ messages
31
+
32
+ Key Features:
33
+ * Automatic serialization: Converts tensors and arrays to JSON-compatible
34
+ formats.
35
+ * Comprehensive validation: Validates frames and episodes.
36
+ * Path management: Standard paths for dataset structure (meta/, data/).
37
+
38
+ Constants:
39
+ DEFAULT_CHUNK_SIZE: Maximum number of episodes per chunk (1000).
40
+ ADVANTAGES_PATH, INFO_PATH, EPISODES_PATH, STATS_PATH: Standard paths.
41
+
42
+ Classes:
43
+ IterableNamespace: Namespace object supporting both dictionary iteration
44
+ and dot notation access.
45
+
46
+ Functions:
47
+ Dictionary manipulation:
48
+ flatten_dict: Flatten nested dictionaries with separator-based keys.
49
+ unflatten_dict: Expand flattened keys into nested dictionaries.
50
+ serialize_dict: Convert tensors/arrays to JSON-serializable format.
51
+
52
+ File I/O:
53
+ load_json, write_json: JSON file operations.
54
+ load_jsonlines, write_jsonlines: JSONL operations.
55
+
56
+ Data validation:
57
+ validate_frame: Validate frame data against feature specifications.
58
+ validate_episode_buffer: Validate episode buffer before adding.
59
+
60
+ Additional helpers (metadata I/O, version compatibility checks, feature conversion, and Hub integration) are documented in their individual docstrings below.
61
+
62
+ Example:
63
+ Load dataset metadata::
64
+
65
+ >>> info = load_info(Path("my_dataset"))
66
+ >>> stats = load_stats(Path("my_dataset"))
67
+ >>> episodes = load_episodes(Path("my_dataset"))
68
+
69
+ Validate a frame::
70
+
71
+ >>> features = {"state": {"dtype": "float32", "shape": (7,)}}
72
+ >>> frame = {"state": np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])}
73
+ >>> validate_frame(frame, features)
74
+ """
75
+
76
+ import contextlib
77
+ import importlib.resources
78
+ import json
79
+ import logging
80
+ from collections.abc import Iterator
81
+ from itertools import accumulate
82
+ from pathlib import Path
83
+ from pprint import pformat
84
+ from types import SimpleNamespace
85
+ from typing import Any
86
+
87
+ import datasets
88
+ import jsonlines
89
+ import numpy as np
90
+ import packaging.version
91
+ import torch
92
+ from datasets.table import embed_table_storage
93
+ from huggingface_hub import DatasetCard, DatasetCardData, HfApi
94
+ from huggingface_hub.errors import RevisionNotFoundError
95
+ from PIL import Image as PILImage
96
+ from torchvision import transforms
97
+
98
+ from opentau.configs.types import DictLike, FeatureType, PolicyFeature
99
+ from opentau.datasets.backward_compatibility import (
100
+ V21_MESSAGE,
101
+ BackwardCompatibilityError,
102
+ ForwardCompatibilityError,
103
+ )
104
+ from opentau.utils.utils import is_valid_numpy_dtype_string
105
+
106
+ DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
107
+
108
+ ADVANTAGES_PATH = "meta/advantages.json"
109
+ INFO_PATH = "meta/info.json"
110
+ EPISODES_PATH = "meta/episodes.jsonl"
111
+ STATS_PATH = "meta/stats.json"
112
+ EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
113
+ TASKS_PATH = "meta/tasks.jsonl"
114
+
115
+ DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
116
+ DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
117
+ DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
118
+
119
+ DATASET_CARD_TEMPLATE = """
120
+ ---
121
+ # Metadata will go there
122
+ ---
123
+ This dataset was created using [OpenTau](https://github.com/TensorAuto/OpenTau).
124
+
125
+ ## {}
126
+
127
+ """
128
+
129
+ DEFAULT_FEATURES = {
130
+ "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
131
+ "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
132
+ "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
133
+ "index": {"dtype": "int64", "shape": (1,), "names": None},
134
+ "task_index": {"dtype": "int64", "shape": (1,), "names": None},
135
+ }
136
+
137
+
138
+ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
139
+ """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
140
+
141
+ For example::
142
+
143
+ >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
144
+ >>> print(flatten_dict(dct))
145
+ {"a/b": 1, "a/c/d": 2, "e": 3}
146
+ """
147
+ items = []
148
+ for k, v in d.items():
149
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
150
+ if isinstance(v, dict):
151
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
152
+ else:
153
+ items.append((new_key, v))
154
+ return dict(items)
155
+
156
+
157
+ def unflatten_dict(d: dict, sep: str = "/") -> dict:
158
+ """Unflatten a dictionary by expanding keys with separators into nested dictionaries.
159
+
160
+ Args:
161
+ d: Dictionary with flattened keys (e.g., {"a/b": 1, "a/c/d": 2}).
162
+ sep: Separator used to split keys. Defaults to "/".
163
+
164
+ Returns:
165
+ Nested dictionary structure (e.g., {"a": {"b": 1, "c": {"d": 2}}}).
166
+
167
+ Example:
168
+ >>> dct = {"a/b": 1, "a/c/d": 2, "e": 3}
169
+ >>> print(unflatten_dict(dct))
170
+ {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
171
+ """
172
+ outdict = {}
173
+ for key, value in d.items():
174
+ parts = key.split(sep)
175
+ d = outdict
176
+ for part in parts[:-1]:
177
+ if part not in d:
178
+ d[part] = {}
179
+ d = d[part]
180
+ d[parts[-1]] = value
181
+ return outdict
182
+
183
+
184
+ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
185
+ """Get a nested item from a dictionary-like object using a flattened key.
186
+
187
+ Args:
188
+ obj: Dictionary-like object to access.
189
+ flattened_key: Flattened key path (e.g., "a/b/c").
190
+ sep: Separator used in the flattened key. Defaults to "/".
191
+
192
+ Returns:
193
+ The value at the nested path specified by the flattened key.
194
+
195
+ Example:
196
+ >>> dct = {"a": {"b": {"c": 42}}}
197
+ >>> get_nested_item(dct, "a/b/c")
198
+ 42
199
+ """
200
+ split_keys = flattened_key.split(sep)
201
+ getter = obj[split_keys[0]]
202
+ if len(split_keys) == 1:
203
+ return getter
204
+
205
+ for key in split_keys[1:]:
206
+ getter = getter[key]
207
+
208
+ return getter
209
+
210
+
211
+ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
212
+ """Serialize a dictionary containing tensors and arrays to JSON-serializable format.
213
+
214
+ Converts torch.Tensor and np.ndarray to lists, and np.generic to Python scalars.
215
+ The dictionary structure is preserved through flattening and unflattening.
216
+
217
+ Args:
218
+ stats: Dictionary containing statistics with tensor/array values.
219
+
220
+ Returns:
221
+ Dictionary with serialized (list/scalar) values in the same structure.
222
+
223
+ Raises:
224
+ NotImplementedError: If a value type is not supported for serialization.
225
+ """
226
+ serialized_dict = {}
227
+ for key, value in flatten_dict(stats).items():
228
+ if isinstance(value, (torch.Tensor, np.ndarray)):
229
+ serialized_dict[key] = value.tolist()
230
+ elif isinstance(value, np.generic):
231
+ serialized_dict[key] = value.item()
232
+ elif isinstance(value, (int, float)):
233
+ serialized_dict[key] = value
234
+ else:
235
+ raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
236
+ return unflatten_dict(serialized_dict)
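For instance, a stats dictionary mixing tensors, arrays, and plain scalars comes back as nested lists and numbers that `json.dump` accepts. A minimal sketch, assuming the helpers are importable from `opentau.datasets.utils` as laid out in this diff:

```python
import numpy as np
import torch

from opentau.datasets.utils import serialize_dict

stats = {
    "state": {"mean": torch.tensor([0.1, 0.2]), "std": np.array([1.0, 1.0])},
    "num_frames": 42,
}
serialized = serialize_dict(stats)
# -> nested dict of plain Python lists and scalars, in the same structure,
#    safe to pass to write_json()/write_stats().
```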
237
+
238
+
239
+ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
240
+ """Embed image bytes into the dataset table before saving to parquet.
241
+
242
+ Converts the dataset to arrow format, embeds image storage, and restores
243
+ the original format.
244
+
245
+ Args:
246
+ dataset: HuggingFace dataset containing images.
247
+
248
+ Returns:
249
+ Dataset with embedded image bytes, ready for parquet serialization.
250
+ """
251
+ format = dataset.format
252
+ dataset = dataset.with_format("arrow")
253
+ dataset = dataset.map(embed_table_storage, batched=False)
254
+ dataset = dataset.with_format(**format)
255
+ return dataset
256
+
257
+
258
+ def load_json(fpath: Path) -> Any:
259
+ """Load JSON data from a file.
260
+
261
+ Args:
262
+ fpath: Path to the JSON file.
263
+
264
+ Returns:
265
+ Parsed JSON data (dict, list, or primitive type).
266
+ """
267
+ with open(fpath) as f:
268
+ return json.load(f)
269
+
270
+
271
+ def write_json(data: dict, fpath: Path) -> None:
272
+ """Write data to a JSON file.
273
+
274
+ Creates parent directories if they don't exist. Uses 4-space indentation
275
+ and allows non-ASCII characters.
276
+
277
+ Args:
278
+ data: Dictionary or other JSON-serializable data to write.
279
+ fpath: Path where the JSON file will be written.
280
+ """
281
+ fpath.parent.mkdir(exist_ok=True, parents=True)
282
+ with open(fpath, "w") as f:
283
+ json.dump(data, f, indent=4, ensure_ascii=False)
284
+
285
+
286
+ def load_jsonlines(fpath: Path) -> list[Any]:
287
+ """Load JSON Lines (JSONL) data from a file.
288
+
289
+ Args:
290
+ fpath: Path to the JSONL file.
291
+
292
+ Returns:
293
+ List of dictionaries, one per line in the file.
294
+ """
295
+ with jsonlines.open(fpath, "r") as reader:
296
+ return list(reader)
297
+
298
+
299
+ def write_jsonlines(data: dict, fpath: Path) -> None:
300
+ """Write data to a JSON Lines (JSONL) file.
301
+
302
+ Creates parent directories if they don't exist. Writes each item in the
303
+ data iterable as a separate line.
304
+
305
+ Args:
306
+ data: Iterable of dictionaries to write (one per line).
307
+ fpath: Path where the JSONL file will be written.
308
+ """
309
+ fpath.parent.mkdir(exist_ok=True, parents=True)
310
+ with jsonlines.open(fpath, "w") as writer:
311
+ writer.write_all(data)
312
+
313
+
314
+ def append_jsonlines(data: dict, fpath: Path) -> None:
315
+ """Append a single dictionary to a JSON Lines (JSONL) file.
316
+
317
+ Creates parent directories if they don't exist. Appends the data as a
318
+ new line to the existing file.
319
+
320
+ Args:
321
+ data: Dictionary to append as a new line.
322
+ fpath: Path to the JSONL file (will be created if it doesn't exist).
323
+ """
324
+ fpath.parent.mkdir(exist_ok=True, parents=True)
325
+ with jsonlines.open(fpath, "a") as writer:
326
+ writer.write(data)
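Taken together, the JSONL helpers support the write-then-append pattern used for episode and task metadata. A small usage sketch (the path is illustrative):

```python
from pathlib import Path

from opentau.datasets.utils import append_jsonlines, load_jsonlines, write_jsonlines

fpath = Path("my_dataset/meta/episodes.jsonl")  # hypothetical location
write_jsonlines([{"episode_index": 0, "length": 120}], fpath)  # create/overwrite
append_jsonlines({"episode_index": 1, "length": 95}, fpath)    # add one line
records = load_jsonlines(fpath)
# [{'episode_index': 0, 'length': 120}, {'episode_index': 1, 'length': 95}]
```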
327
+
328
+
329
+ def write_info(info: dict, local_dir: Path) -> None:
330
+ """Write dataset info dictionary to the standard info.json file.
331
+
332
+ Args:
333
+ info: Dataset info dictionary to write.
334
+ local_dir: Root directory of the dataset where meta/info.json will be written.
335
+ """
336
+ write_json(info, local_dir / INFO_PATH)
337
+
338
+
339
+ def load_info(local_dir: Path) -> dict:
340
+ """Load dataset info from the standard info.json file.
341
+
342
+ Converts feature shapes from lists to tuples for consistency.
343
+
344
+ Args:
345
+ local_dir: Root directory of the dataset containing meta/info.json.
346
+
347
+ Returns:
348
+ Dataset info dictionary with feature shapes as tuples.
349
+ """
350
+ info = load_json(local_dir / INFO_PATH)
351
+ for ft in info["features"].values():
352
+ ft["shape"] = tuple(ft["shape"])
353
+ return info
354
+
355
+
356
+ def write_stats(stats: dict, local_dir: Path) -> None:
357
+ """Write dataset statistics to the standard stats.json file.
358
+
359
+ Serializes tensors and arrays to JSON-compatible format before writing.
360
+
361
+ Args:
362
+ stats: Dictionary containing dataset statistics (may contain tensors/arrays).
363
+ local_dir: Root directory of the dataset where meta/stats.json will be written.
364
+ """
365
+ serialized_stats = serialize_dict(stats)
366
+ write_json(serialized_stats, local_dir / STATS_PATH)
367
+
368
+
369
+ def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
370
+ """Convert statistics dictionary values to numpy arrays.
371
+
372
+ Flattens the dictionary, converts all values to numpy arrays, then
373
+ unflattens to restore the original structure.
374
+
375
+ Args:
376
+ stats: Dictionary with statistics (values may be lists or other types).
377
+
378
+ Returns:
379
+ Dictionary with the same structure but all values as numpy arrays.
380
+ """
381
+ stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
382
+ return unflatten_dict(stats)
383
+
384
+
385
+ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
386
+ """Load dataset statistics from the standard stats.json file.
387
+
388
+ Args:
389
+ local_dir: Root directory of the dataset containing meta/stats.json.
390
+
391
+ Returns:
392
+ Dictionary with statistics as numpy arrays, or None if the file doesn't exist.
393
+ """
394
+ if not (local_dir / STATS_PATH).exists():
395
+ return None
396
+ stats = load_json(local_dir / STATS_PATH)
397
+ return cast_stats_to_numpy(stats)
398
+
399
+
400
+ def load_advantages(local_dir: Path) -> dict | None:
401
+ """Load advantage values from the advantages.json file.
402
+
403
+ Advantages are keyed by (episode_index, timestamp) tuples in the JSON file
404
+ as comma-separated strings, which are converted to tuple keys.
405
+
406
+ Args:
407
+ local_dir: Root directory of the dataset containing meta/advantages.json.
408
+
409
+ Returns:
410
+ Dictionary mapping (episode_index, timestamp) tuples to advantage values,
411
+ or None if the file doesn't exist.
412
+ """
413
+ if not (local_dir / ADVANTAGES_PATH).exists():
414
+ return None
415
+ advantages = load_json(local_dir / ADVANTAGES_PATH)
416
+ return {(int(k.split(",")[0]), float(k.split(",")[1])): v for k, v in advantages.items()}
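The on-disk keys are plain `"<episode_index>,<timestamp>"` strings; loading converts them into tuple keys. A minimal sketch of the expected layout (the dataset root is illustrative):

```python
import json
from pathlib import Path

from opentau.datasets.utils import load_advantages

root = Path("my_dataset")  # hypothetical dataset root
(root / "meta").mkdir(parents=True, exist_ok=True)
# Keys are "<episode_index>,<timestamp>" strings in meta/advantages.json.
(root / "meta" / "advantages.json").write_text(json.dumps({"0,0.0": 0.5, "0,0.1": -0.2}))

advantages = load_advantages(root)
# {(0, 0.0): 0.5, (0, 0.1): -0.2}
```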
417
+
418
+
419
+ def write_task(task_index: int, task: str, local_dir: Path) -> None:
420
+ """Write a task entry to the tasks.jsonl file.
421
+
422
+ Args:
423
+ task_index: Integer index of the task.
424
+ task: Natural-language task description string (later used as a mapping key by load_tasks).
425
+ local_dir: Root directory of the dataset where meta/tasks.jsonl will be written.
426
+ """
427
+ task_dict = {
428
+ "task_index": task_index,
429
+ "task": task,
430
+ }
431
+ append_jsonlines(task_dict, local_dir / TASKS_PATH)
432
+
433
+
434
+ def load_tasks(local_dir: Path) -> tuple[dict, dict]:
435
+ """Load tasks from the tasks.jsonl file.
436
+
437
+ Args:
438
+ local_dir: Root directory of the dataset containing meta/tasks.jsonl.
439
+
440
+ Returns:
441
+ Tuple of (tasks_dict, task_to_index_dict):
442
+ - tasks_dict: Dictionary mapping task_index to task description.
443
+ - task_to_index_dict: Dictionary mapping task description to task_index.
444
+ """
445
+ tasks = load_jsonlines(local_dir / TASKS_PATH)
446
+ tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
447
+ task_to_task_index = {task: task_index for task_index, task in tasks.items()}
448
+ return tasks, task_to_task_index
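`tasks.jsonl` stores one `{"task_index", "task"}` record per line, and `load_tasks` returns both directions of the mapping. A minimal sketch (the dataset root and task strings are illustrative):

```python
from pathlib import Path

from opentau.datasets.utils import load_tasks, write_task

root = Path("my_dataset")  # hypothetical dataset root
write_task(0, "pick up the red block", root)  # appends to meta/tasks.jsonl
write_task(1, "open the drawer", root)

tasks, task_to_task_index = load_tasks(root)
# tasks              -> {0: 'pick up the red block', 1: 'open the drawer'}
# task_to_task_index -> {'pick up the red block': 0, 'open the drawer': 1}
```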
449
+
450
+
451
+ def write_episode(episode: dict, local_dir: Path) -> None:
452
+ """Write an episode entry to the episodes.jsonl file.
453
+
454
+ Args:
455
+ episode: Episode dictionary containing episode_index, tasks, length, etc.
456
+ local_dir: Root directory of the dataset where meta/episodes.jsonl will be written.
457
+ """
458
+ append_jsonlines(episode, local_dir / EPISODES_PATH)
459
+
460
+
461
+ def load_episodes(local_dir: Path) -> dict:
462
+ """Load episodes from the episodes.jsonl file.
463
+
464
+ Args:
465
+ local_dir: Root directory of the dataset containing meta/episodes.jsonl.
466
+
467
+ Returns:
468
+ Dictionary mapping episode_index to episode information dictionary.
469
+ """
470
+ episodes = load_jsonlines(local_dir / EPISODES_PATH)
471
+ return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
472
+
473
+
474
+ def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path) -> None:
475
+ """Write episode statistics to the episodes_stats.jsonl file.
476
+
477
+ Serializes tensors and arrays in the stats before writing.
478
+
479
+ Args:
480
+ episode_index: Index of the episode.
481
+ episode_stats: Dictionary containing statistics for the episode (may contain tensors/arrays).
482
+ local_dir: Root directory of the dataset where meta/episodes_stats.jsonl will be written.
483
+ """
484
+ # Wrap the stats under a "stats" key: `episode_stats` may already contain an
+ # "episode_index" entry (the stats of that column), which is a dict rather than an integer.
486
+ episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
487
+ append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
488
+
489
+
490
+ def load_episodes_stats(local_dir: Path) -> dict:
491
+ """Load episode statistics from the episodes_stats.jsonl file.
492
+
493
+ Args:
494
+ local_dir: Root directory of the dataset containing meta/episodes_stats.jsonl.
495
+
496
+ Returns:
497
+ Dictionary mapping episode_index to statistics dictionary (with numpy arrays).
498
+ """
499
+ episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
500
+ return {
501
+ item["episode_index"]: cast_stats_to_numpy(item["stats"])
502
+ for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
503
+ }
504
+
505
+
506
+ def backward_compatible_episodes_stats(
507
+ stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
508
+ ) -> dict[str, dict[str, np.ndarray]]:
509
+ """Create episode-level statistics from global statistics for backward compatibility.
510
+
511
+ In older dataset versions, statistics were stored globally rather than per-episode.
512
+ This function creates per-episode statistics by assigning the same global stats
513
+ to each episode.
514
+
515
+ Args:
516
+ stats: Global statistics dictionary.
517
+ episodes: List of episode indices.
518
+
519
+ Returns:
520
+ Dictionary mapping episode_index to the same statistics dictionary.
521
+ """
522
+ return dict.fromkeys(episodes, stats)
523
+
524
+
525
+ def load_image_as_numpy(
526
+ fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
527
+ ) -> np.ndarray:
528
+ """Load an image file as a numpy array.
529
+
530
+ Args:
531
+ fpath: Path to the image file.
532
+ dtype: Data type for the array. Defaults to np.float32.
533
+ channel_first: If True, return array in (C, H, W) format; otherwise (H, W, C).
534
+ Defaults to True.
535
+
536
+ Returns:
537
+ Image as numpy array. If dtype is floating point, values are normalized to [0, 1].
538
+ Otherwise, values are in [0, 255].
539
+ """
540
+ img = PILImage.open(fpath).convert("RGB")
541
+ img_array = np.array(img, dtype=dtype)
542
+ if channel_first: # (H, W, C) -> (C, H, W)
543
+ img_array = np.transpose(img_array, (2, 0, 1))
544
+ if np.issubdtype(dtype, np.floating):
545
+ img_array /= 255.0
546
+ return img_array
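The dtype controls both the value range and the need for normalization, while `channel_first` controls the axis order. A small sketch (the image path is illustrative):

```python
import numpy as np

from opentau.datasets.utils import load_image_as_numpy

# Float output: channel-first (C, H, W), values normalized to [0, 1].
img_f = load_image_as_numpy("frame_000000.png", dtype=np.float32, channel_first=True)

# Integer output: channel-last (H, W, C), raw values in [0, 255].
img_u8 = load_image_as_numpy("frame_000000.png", dtype=np.uint8, channel_first=False)
```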
547
+
548
+
549
+ def hf_transform_to_torch(items_dict: dict[str, list[Any]]):
+ """Transform items from a Hugging Face dataset (pyarrow) into torch tensors.
+ Importantly, images are converted from PIL's channel-last (h, w, c) uint8
+ representation to a channel-first (c, h, w) float32 torch representation in
+ the range [0, 1]. Strings are left as-is, and columns whose first item is
+ None are left untouched.
+ """
555
+ for key in items_dict:
556
+ first_item = items_dict[key][0]
557
+ if isinstance(first_item, PILImage.Image):
558
+ to_tensor = transforms.ToTensor()
559
+ items_dict[key] = [to_tensor(img) for img in items_dict[key]]
560
+ elif first_item is None:
561
+ pass
562
+ else:
563
+ items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
564
+ return items_dict
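This is typically registered as an on-the-fly format transform on a Hugging Face dataset. A hedged sketch (the column names and dataset construction are illustrative):

```python
import datasets
from PIL import Image

from opentau.datasets.utils import hf_transform_to_torch

features = datasets.Features(
    {
        "observation.image": datasets.Image(),
        "state": datasets.Sequence(datasets.Value("float32"), length=2),
    }
)
hf_dataset = datasets.Dataset.from_dict(
    {"observation.image": [Image.new("RGB", (32, 32))], "state": [[0.0, 1.0]]},
    features=features,
)
hf_dataset.set_transform(hf_transform_to_torch)

item = hf_dataset[0]
# item["observation.image"] -> float32 tensor of shape (3, 32, 32), values in [0, 1]
# item["state"]             -> tensor([0., 1.])
```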
565
+
566
+
567
+ def is_valid_version(version: str) -> bool:
568
+ """Check if a version string is valid and can be parsed.
569
+
570
+ Args:
571
+ version: Version string to validate.
572
+
573
+ Returns:
574
+ True if the version string is valid, False otherwise.
575
+ """
576
+ try:
577
+ packaging.version.parse(version)
578
+ return True
579
+ except packaging.version.InvalidVersion:
580
+ return False
581
+
582
+
583
+ def check_version_compatibility(
584
+ repo_id: str,
585
+ version_to_check: str | packaging.version.Version,
586
+ current_version: str | packaging.version.Version,
587
+ enforce_breaking_major: bool = True,
588
+ ) -> None:
589
+ """Check compatibility between a dataset version and the current codebase version.
590
+
591
+ Args:
592
+ repo_id: Repository ID of the dataset.
593
+ version_to_check: Version of the dataset to check.
594
+ current_version: Current codebase version.
595
+ enforce_breaking_major: If True, raise error for major version mismatches.
596
+ Defaults to True.
597
+
598
+ Raises:
599
+ BackwardCompatibilityError: If the dataset version is too old (major version mismatch).
600
+ """
601
+ v_check = (
602
+ packaging.version.parse(version_to_check)
603
+ if not isinstance(version_to_check, packaging.version.Version)
604
+ else version_to_check
605
+ )
606
+ v_current = (
607
+ packaging.version.parse(current_version)
608
+ if not isinstance(current_version, packaging.version.Version)
609
+ else current_version
610
+ )
611
+ if v_check.major < v_current.major and enforce_breaking_major:
612
+ raise BackwardCompatibilityError(repo_id, v_check)
613
+ elif v_check.minor < v_current.minor:
614
+ logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
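In short: an older major version raises, while an older minor version only logs a warning. A minimal sketch (the repo id and version strings are illustrative):

```python
from opentau.datasets.backward_compatibility import BackwardCompatibilityError
from opentau.datasets.utils import check_version_compatibility

# Same major, older minor: logs a warning and returns normally.
check_version_compatibility("user/my_dataset", "v2.0", "v2.1")

# Older major: raises unless enforce_breaking_major=False.
try:
    check_version_compatibility("user/my_dataset", "v1.6", "v2.1")
except BackwardCompatibilityError:
    pass
```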
615
+
616
+
617
+ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
618
+ """Returns available valid versions (branches and tags) on given repo."""
619
+ api = HfApi()
620
+ repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
621
+ repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
622
+ repo_versions = []
623
+ for ref in repo_refs:
624
+ with contextlib.suppress(packaging.version.InvalidVersion):
625
+ repo_versions.append(packaging.version.parse(ref))
626
+
627
+ return repo_versions
628
+
629
+
630
+ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
631
+ """
632
+ Return the requested version if it is available on the repo, or the latest compatible one.
+ Otherwise, raise a backward/forward compatibility error (or a RevisionNotFoundError if the
+ repo has no version tags at all).
634
+ """
635
+ target_version = (
636
+ packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
637
+ )
638
+ hub_versions = get_repo_versions(repo_id)
639
+
640
+ if not hub_versions:
641
+ raise RevisionNotFoundError(
642
+ f"""Your dataset must be tagged with a codebase version.
643
+ Assuming _version_ is the codebase_version value in the info.json, you can run this:
644
+ ```python
645
+ from huggingface_hub import HfApi
646
+
647
+ hub_api = HfApi()
648
+ hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
649
+ ```
650
+ """
651
+ )
652
+
653
+ if target_version in hub_versions:
654
+ return f"v{target_version}"
655
+
656
+ compatibles = [
657
+ v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
658
+ ]
659
+ if compatibles:
660
+ return_version = max(compatibles)
661
+ if return_version < target_version:
662
+ logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
663
+ return f"v{return_version}"
664
+
665
+ lower_major = [v for v in hub_versions if v.major < target_version.major]
666
+ if lower_major:
667
+ raise BackwardCompatibilityError(repo_id, max(lower_major))
668
+
669
+ upper_versions = [v for v in hub_versions if v > target_version]
670
+ assert len(upper_versions) > 0
671
+ raise ForwardCompatibilityError(repo_id, min(upper_versions))
672
+
673
+
674
+ def get_hf_features_from_features(features: dict) -> datasets.Features:
675
+ """Convert dataset features dictionary to HuggingFace Features object.
676
+
677
+ Maps feature types and shapes to appropriate HuggingFace feature types
678
+ (Image, Value, Sequence, Array2D, Array3D, Array4D, Array5D).
679
+
680
+ Args:
681
+ features: Dictionary mapping feature names to feature specifications
682
+ with 'dtype' and 'shape' keys.
683
+
684
+ Returns:
685
+ HuggingFace Features object compatible with the dataset library.
686
+
687
+ Raises:
688
+ ValueError: If a feature shape is not supported (more than 5 dimensions).
689
+ """
690
+ hf_features = {}
691
+ for key, ft in features.items():
692
+ if ft["dtype"] == "video":
693
+ continue
694
+ elif ft["dtype"] == "image":
695
+ hf_features[key] = datasets.Image()
696
+ elif ft["shape"] == (1,):
697
+ hf_features[key] = datasets.Value(dtype=ft["dtype"])
698
+ elif len(ft["shape"]) == 1:
699
+ hf_features[key] = datasets.Sequence(
700
+ length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
701
+ )
702
+ elif len(ft["shape"]) == 2:
703
+ hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
704
+ elif len(ft["shape"]) == 3:
705
+ hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
706
+ elif len(ft["shape"]) == 4:
707
+ hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
708
+ elif len(ft["shape"]) == 5:
709
+ hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
710
+ else:
711
+ raise ValueError(f"Corresponding feature is not valid: {ft}")
712
+
713
+ return datasets.Features(hf_features)
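Concretely, each entry's dtype and shape decide the Hugging Face feature type, and "video" columns are excluded from the arrow table. A minimal sketch:

```python
from opentau.datasets.utils import get_hf_features_from_features

features = {
    "observation.image": {"dtype": "video", "shape": (3, 224, 224), "names": None},  # skipped
    "state": {"dtype": "float32", "shape": (7,), "names": None},       # Sequence(length=7)
    "reward": {"dtype": "float32", "shape": (1,), "names": None},      # Value("float32")
    "depth": {"dtype": "float32", "shape": (224, 224), "names": None}, # Array2D
}
hf_features = get_hf_features_from_features(features)
# hf_features.keys() -> ['state', 'reward', 'depth']; the "video" entry is dropped
```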
714
+
715
+
716
+ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
717
+ """Convert dataset features to policy feature format.
718
+
719
+ Maps dataset features to policy feature types (VISUAL, ENV, STATE, ACTION)
720
+ based on feature names and data types.
721
+
722
+ Args:
723
+ features: Dictionary mapping feature names to feature specifications.
724
+
725
+ Returns:
726
+ Dictionary mapping feature names to PolicyFeature objects.
727
+
728
+ Raises:
729
+ ValueError: If a visual feature doesn't have 3 dimensions.
730
+ """
731
+ # TODO(aliberts): Implement "type" in dataset features and simplify this
732
+ policy_features = {}
733
+ for key, ft in features.items():
734
+ shape = ft["shape"]
735
+ if ft["dtype"] in ["image", "video"]:
736
+ type = FeatureType.VISUAL
737
+ if len(shape) != 3:
738
+ raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
739
+ elif key == "observation.environment_state":
740
+ type = FeatureType.ENV
741
+ elif key == "state":
742
+ type = FeatureType.STATE
743
+ elif key == "actions":
744
+ type = FeatureType.ACTION
745
+ else:
746
+ continue
747
+
748
+ policy_features[key] = PolicyFeature(
749
+ type=type,
750
+ shape=shape,
751
+ )
752
+
753
+ return policy_features
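For example, camera streams map to VISUAL features while "state" and "actions" map to STATE and ACTION; unrecognized keys are dropped. A minimal sketch:

```python
from opentau.configs.types import FeatureType
from opentau.datasets.utils import dataset_to_policy_features

features = {
    "observation.image": {"dtype": "video", "shape": (3, 224, 224)},
    "state": {"dtype": "float32", "shape": (7,)},
    "actions": {"dtype": "float32", "shape": (7,)},
    "timestamp": {"dtype": "float32", "shape": (1,)},  # not a policy feature, dropped
}
policy_features = dataset_to_policy_features(features)
# policy_features["observation.image"].type == FeatureType.VISUAL
# policy_features["state"].type == FeatureType.STATE
# "timestamp" is absent from the result
```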
754
+
755
+
756
+ def create_empty_dataset_info(
757
+ codebase_version: str,
758
+ fps: int,
759
+ robot_type: str,
760
+ features: dict,
761
+ use_videos: bool,
762
+ ) -> dict:
763
+ """Create an empty dataset info dictionary with default values.
764
+
765
+ Args:
766
+ codebase_version: Version of the codebase used to create the dataset.
767
+ fps: Frames per second used during data collection.
768
+ robot_type: Type of robot used (can be None).
769
+ features: Dictionary of feature specifications.
770
+ use_videos: Whether videos are used for visual modalities.
771
+
772
+ Returns:
773
+ Dictionary containing dataset metadata with initialized counters and paths.
774
+ """
775
+ return {
776
+ "codebase_version": codebase_version,
777
+ "robot_type": robot_type,
778
+ "total_episodes": 0,
779
+ "total_frames": 0,
780
+ "total_tasks": 0,
781
+ "total_videos": 0,
782
+ "total_chunks": 0,
783
+ "chunks_size": DEFAULT_CHUNK_SIZE,
784
+ "fps": fps,
785
+ "splits": {},
786
+ "data_path": DEFAULT_PARQUET_PATH,
787
+ "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
788
+ "features": features,
789
+ }
790
+
791
+
792
+ def get_episode_data_index(
793
+ episode_dicts: dict[dict], episodes: list[int] | None = None
794
+ ) -> tuple[dict[str, torch.Tensor], dict[int, int]]:
795
+ """Compute data indices for episodes in a flattened dataset.
796
+
797
+ Calculates start and end indices for each episode in a concatenated dataset,
798
+ and creates a mapping from episode index to position in the episodes list.
799
+
800
+ Args:
801
+ episode_dicts: Dictionary mapping episode_index to episode info dicts
802
+ containing 'length' keys.
803
+ episodes: Optional list of episode indices to include. If None, uses all
804
+ episodes from episode_dicts.
805
+
806
+ Returns:
807
+ Tuple of (episode_data_index, ep2idx):
808
+ - episode_data_index: Dictionary with 'from' and 'to' tensors indicating
809
+ start and end indices for each episode.
810
+ - ep2idx: Dictionary mapping episode_index to position in the episodes list.
811
+ """
812
+ # `episodes_dicts` are not necessarily sorted, or starting with episode_index 0.
813
+ episode_lengths = {edict["episode_index"]: edict["length"] for edict in episode_dicts.values()}
814
+
815
+ if episodes is None:
816
+ episodes = list(episode_lengths.keys())
817
+
818
+ episode_lengths = [episode_lengths[ep_idx] for ep_idx in episodes]
819
+ cumulative_lengths = list(accumulate(episode_lengths))
820
+ start = [0] + cumulative_lengths[:-1]
821
+ end = cumulative_lengths
822
+ ep2idx = {ep_idx: i for i, ep_idx in enumerate(episodes)}
823
+ return {"from": torch.LongTensor(start), "to": torch.LongTensor(end)}, ep2idx
824
+
825
+
826
+ def check_timestamps_sync(
827
+ timestamps: np.ndarray,
828
+ episode_indices: np.ndarray,
829
+ episode_data_index: dict[str, np.ndarray],
830
+ fps: int,
831
+ tolerance_s: float,
832
+ raise_value_error: bool = True,
833
+ ) -> bool:
834
+ """
835
+ This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
836
+ to account for possible numerical error.
837
+
838
+ Args:
839
+ timestamps (np.ndarray): Array of timestamps in seconds.
840
+ episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
841
+ episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
842
+ which identifies indices for the end of each episode.
843
+ fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
844
+ tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
845
+ raise_value_error (bool): Whether to raise a ValueError if the check fails.
846
+
847
+ Returns:
848
+ bool: True if all checked timestamp differences lie within tolerance, False otherwise.
849
+
850
+ Raises:
851
+ ValueError: If the check fails and `raise_value_error` is True.
852
+ """
853
+ if timestamps.shape != episode_indices.shape:
854
+ raise ValueError(
855
+ "timestamps and episode_indices should have the same shape. "
856
+ f"Found {timestamps.shape=} and {episode_indices.shape=}."
857
+ )
858
+
859
+ # Consecutive differences
860
+ diffs = np.diff(timestamps)
861
+ within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
862
+
863
+ # Mask to ignore differences at the boundaries between episodes
864
+ mask = np.ones(len(diffs), dtype=bool)
865
+ ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
866
+ mask[ignored_diffs] = False
867
+ filtered_within_tolerance = within_tolerance[mask]
868
+
869
+ # Check if all remaining diffs are within tolerance
870
+ if not np.all(filtered_within_tolerance):
871
+ # Track original indices before masking
872
+ original_indices = np.arange(len(diffs))
873
+ filtered_indices = original_indices[mask]
874
+ outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
875
+ outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
876
+
877
+ outside_tolerances = []
878
+ for idx in outside_tolerance_indices:
879
+ entry = {
880
+ "timestamps": [timestamps[idx], timestamps[idx + 1]],
881
+ "diff": diffs[idx],
882
+ "episode_index": episode_indices[idx].item()
883
+ if hasattr(episode_indices[idx], "item")
884
+ else episode_indices[idx],
885
+ }
886
+ outside_tolerances.append(entry)
887
+
888
+ if raise_value_error:
889
+ raise ValueError(
890
+ f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
891
+ This might be due to synchronization issues during data collection.
892
+ \n{pformat(outside_tolerances)}"""
893
+ )
894
+ return False
895
+
896
+ return True
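For instance, at 10 fps consecutive timestamps should be 0.1 s apart within the tolerance, and the jump between episodes is masked out. A minimal sketch:

```python
import numpy as np

from opentau.datasets.utils import check_timestamps_sync

timestamps = np.array([0.0, 0.1, 0.2, 0.0, 0.1])   # two episodes recorded at 10 fps
episode_indices = np.array([0, 0, 0, 1, 1])
episode_data_index = {"to": np.array([3, 5])}      # end index of each episode

ok = check_timestamps_sync(
    timestamps, episode_indices, episode_data_index, fps=10, tolerance_s=1e-4
)
# ok is True: the 0.2 -> 0.0 jump at the episode boundary is ignored by the mask.
```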
897
+
898
+
899
+ DeltaTimestampParam = dict[str, np.ndarray]
900
+ DeltaTimestampInfo = tuple[DeltaTimestampParam, DeltaTimestampParam, DeltaTimestampParam, DeltaTimestampParam]
901
+
902
+
903
+ def get_delta_indices_soft(delta_timestamps_info: DeltaTimestampInfo, fps: int) -> DeltaTimestampParam:
904
+ r"""Returns soft indices (not necessarily integer) for delta timestamps based on the provided information.
905
+ Soft indices are computed by sampling from a normal distribution defined by the mean and standard deviation
906
+ and clipping the values to the specified lower and upper bounds.
907
+ Note: Soft indices can be converted to integer indices by either rounding or interpolation.
908
+ """
909
+ soft_indices = {}
910
+ mean, std, lower, upper = delta_timestamps_info
911
+ for key in mean:
912
+ dT = np.random.normal(mean[key], std[key]).clip(lower[key], upper[key]) # noqa: N806
913
+ soft_indices[key] = dT * fps
914
+
915
+ return soft_indices
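The four dictionaries share keys and define, per key, the mean, standard deviation, and clip bounds of the sampled time offsets in seconds; the result is expressed in fractional frame indices. A minimal sketch with illustrative numbers:

```python
import numpy as np

from opentau.datasets.utils import get_delta_indices_soft

mean = {"actions": np.array([0.0, 0.1, 0.2])}    # desired offsets in seconds
std = {"actions": np.array([0.0, 0.02, 0.02])}   # jitter per offset
lower = {"actions": np.array([0.0, 0.05, 0.15])}
upper = {"actions": np.array([0.0, 0.15, 0.25])}

soft_indices = get_delta_indices_soft((mean, std, lower, upper), fps=30)
# soft_indices["actions"] is an array of 3 fractional frame offsets,
# roughly [0.0, ~3.0, ~6.0] at 30 fps, clipped to the given bounds.
```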
916
+
917
+
918
+ def cycle(iterable):
919
+ """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
920
+
921
+ See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
922
+ """
923
+ iterator = iter(iterable)
924
+ while True:
925
+ try:
926
+ yield next(iterator)
927
+ except StopIteration:
928
+ iterator = iter(iterable)
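A typical use is wrapping a DataLoader so training can draw batches indefinitely. A minimal sketch:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from opentau.datasets.utils import cycle

dataset = TensorDataset(torch.arange(10).unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_iter = cycle(dataloader)
for _ in range(5):  # more iterations than the loader has batches
    (batch,) = next(batch_iter)  # the loader restarts transparently
```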
929
+
930
+
931
+ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
932
+ """Create a branch on a existing Hugging Face repo. Delete the branch if it already
933
+ exists before creating it.
934
+ """
935
+ api = HfApi()
936
+
937
+ branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
938
+ refs = [branch.ref for branch in branches]
939
+ ref = f"refs/heads/{branch}"
940
+ if ref in refs:
941
+ api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
942
+
943
+ api.create_branch(repo_id, repo_type=repo_type, branch=branch)
944
+
945
+
946
+ def create_lerobot_dataset_card(
947
+ tags: list | None = None,
948
+ dataset_info: dict | None = None,
949
+ **kwargs,
950
+ ) -> DatasetCard:
951
+ """
952
+ Keyword arguments will be used to replace values in `src/opentau/datasets/card_template.md`.
953
+ Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
954
+ """
955
+ card_tags = ["OpenTau"]
956
+
957
+ if tags:
958
+ card_tags += tags
959
+ if dataset_info:
960
+ dataset_structure = "[meta/info.json](meta/info.json):\n"
961
+ dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
962
+ kwargs = {**kwargs, "dataset_structure": dataset_structure}
963
+ card_data = DatasetCardData(
964
+ license=kwargs.get("license"),
965
+ tags=card_tags,
966
+ task_categories=["robotics"],
967
+ configs=[
968
+ {
969
+ "config_name": "default",
970
+ "data_files": "data/*/*.parquet",
971
+ }
972
+ ],
973
+ )
974
+
975
+ card_template = (importlib.resources.files("opentau.datasets") / "card_template.md").read_text()
976
+
977
+ return DatasetCard.from_template(
978
+ card_data=card_data,
979
+ template_str=card_template,
980
+ **kwargs,
981
+ )
982
+
983
+
984
+ class IterableNamespace(SimpleNamespace):
985
+ """
986
+ A namespace object that supports both dictionary-like iteration and dot notation access.
987
+ Automatically converts nested dictionaries into IterableNamespaces.
988
+
989
+ This class extends SimpleNamespace to provide:
990
+ - Dictionary-style iteration over keys
991
+ - Access to items via both dot notation (obj.key) and brackets (obj["key"])
992
+ - Dictionary-like methods: items(), keys(), values()
993
+ - Recursive conversion of nested dictionaries
994
+
995
+ Args:
996
+ dictionary: Optional dictionary to initialize the namespace
997
+ **kwargs: Additional keyword arguments passed to SimpleNamespace
998
+
999
+ Examples:
1000
+ >>> data = {"name": "Alice", "details": {"age": 25}}
1001
+ >>> ns = IterableNamespace(data)
1002
+ >>> ns.name
1003
+ 'Alice'
1004
+ >>> ns.details.age
1005
+ 25
1006
+ >>> list(ns.keys())
1007
+ ['name', 'details']
1008
+ >>> for key, value in ns.items():
1009
+ ... print(f"{key}: {value}")
1010
+ name: Alice
1011
+ details: IterableNamespace(age=25)
1012
+ """
1013
+
1014
+ def __init__(self, dictionary: dict[str, Any] | None = None, **kwargs):
1015
+ super().__init__(**kwargs)
1016
+ if dictionary is not None:
1017
+ for key, value in dictionary.items():
1018
+ if isinstance(value, dict):
1019
+ setattr(self, key, IterableNamespace(value))
1020
+ else:
1021
+ setattr(self, key, value)
1022
+
1023
+ def __iter__(self) -> Iterator[str]:
1024
+ return iter(vars(self))
1025
+
1026
+ def __getitem__(self, key: str) -> Any:
1027
+ return vars(self)[key]
1028
+
1029
+ def items(self):
1030
+ return vars(self).items()
1031
+
1032
+ def values(self):
1033
+ return vars(self).values()
1034
+
1035
+ def keys(self):
1036
+ return vars(self).keys()
1037
+
1038
+
1039
+ def validate_frame(frame: dict, features: dict) -> None:
1040
+ """Validate that a frame dictionary matches the expected features.
1041
+
1042
+ Checks that all required features are present, no unexpected features exist,
1043
+ and that feature types and shapes match the specification.
1044
+
1045
+ Args:
1046
+ frame: Dictionary containing frame data to validate.
1047
+ features: Dictionary of expected feature specifications.
1048
+
1049
+ Raises:
1050
+ ValueError: If the frame doesn't match the feature specifications.
1051
+ """
1052
+ optional_features = {"timestamp"}
1053
+ expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
1054
+ actual_features = set(frame.keys())
1055
+
1056
+ error_message = validate_features_presence(actual_features, expected_features, optional_features)
1057
+
1058
+ if "task" in frame:
1059
+ error_message += validate_feature_string("task", frame["task"])
1060
+
1061
+ common_features = actual_features & (expected_features | optional_features)
1062
+ for name in common_features - {"task"}:
1063
+ error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
1064
+
1065
+ if error_message:
1066
+ raise ValueError(error_message)
1067
+
1068
+
1069
+ def validate_features_presence(
1070
+ actual_features: set[str], expected_features: set[str], optional_features: set[str]
1071
+ ) -> str:
1072
+ """Validate that required features are present and no unexpected features exist.
1073
+
1074
+ Args:
1075
+ actual_features: Set of feature names actually present.
1076
+ expected_features: Set of feature names that must be present.
1077
+ optional_features: Set of feature names that may be present but aren't required.
1078
+
1079
+ Returns:
1080
+ Error message string (empty if validation passes).
1081
+ """
1082
+ error_message = ""
1083
+ missing_features = expected_features - actual_features
1084
+ extra_features = actual_features - (expected_features | optional_features)
1085
+
1086
+ if missing_features or extra_features:
1087
+ error_message += "Feature mismatch in `frame` dictionary:\n"
1088
+ if missing_features:
1089
+ error_message += f"Missing features: {missing_features}\n"
1090
+ if extra_features:
1091
+ error_message += f"Extra features: {extra_features}\n"
1092
+
1093
+ return error_message
1094
+
1095
+
1096
+ def validate_feature_dtype_and_shape(
1097
+ name: str, feature: dict, value: np.ndarray | PILImage.Image | str
1098
+ ) -> str:
1099
+ """Validate that a feature value matches its expected dtype and shape.
1100
+
1101
+ Routes to appropriate validation function based on feature type.
1102
+
1103
+ Args:
1104
+ name: Name of the feature being validated.
1105
+ feature: Feature specification dictionary with 'dtype' and 'shape' keys.
1106
+ value: Actual value to validate.
1107
+
1108
+ Returns:
1109
+ Error message string (empty if validation passes).
1110
+
1111
+ Raises:
1112
+ NotImplementedError: If the feature dtype is not supported.
1113
+ """
1114
+ expected_dtype = feature["dtype"]
1115
+ expected_shape = feature["shape"]
1116
+ if is_valid_numpy_dtype_string(expected_dtype):
1117
+ return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
1118
+ elif expected_dtype in ["image", "video"]:
1119
+ return validate_feature_image_or_video(name, expected_shape, value)
1120
+ elif expected_dtype == "string":
1121
+ return validate_feature_string(name, value)
1122
+ else:
1123
+ raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
1124
+
1125
+
1126
+ def validate_feature_numpy_array(
1127
+ name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
1128
+ ) -> str:
1129
+ """Validate that a numpy array feature matches expected dtype and shape.
1130
+
1131
+ Args:
1132
+ name: Name of the feature being validated.
1133
+ expected_dtype: Expected numpy dtype as a string.
1134
+ expected_shape: Expected shape as a list of integers.
1135
+ value: Actual numpy array to validate.
1136
+
1137
+ Returns:
1138
+ Error message string (empty if validation passes).
1139
+ """
1140
+ error_message = ""
1141
+ if isinstance(value, np.ndarray):
1142
+ actual_dtype = value.dtype
1143
+ actual_shape = value.shape
1144
+
1145
+ if actual_dtype != np.dtype(expected_dtype):
1146
+ error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
1147
+
1148
+ if actual_shape != expected_shape:
1149
+ error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
1150
+ else:
1151
+ error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
1152
+
1153
+ return error_message
1154
+
1155
+
1156
+ def validate_feature_image_or_video(
1157
+ name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
1158
+ ) -> str:
1159
+ """Validate that an image or video feature matches expected shape.
1160
+
1161
+ Supports both channel-first (C, H, W) and channel-last (H, W, C) formats.
1162
+
1163
+ Args:
1164
+ name: Name of the feature being validated.
1165
+ expected_shape: Expected shape as [C, H, W].
1166
+ value: Actual image/video value (PIL Image or numpy array).
1167
+
1168
+ Returns:
1169
+ Error message string (empty if validation passes).
1170
+
1171
+ Note:
1172
+ Pixel value range validation ([0,1] for float, [0,255] for uint8) is
1173
+ performed by the image writer threads, not here.
1174
+ """
1175
+ # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
1176
+ error_message = ""
1177
+ if isinstance(value, np.ndarray):
1178
+ actual_shape = value.shape
1179
+ c, h, w = expected_shape
1180
+ if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
1181
+ error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
1182
+ elif isinstance(value, PILImage.Image):
1183
+ pass
1184
+ else:
1185
+ error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
1186
+
1187
+ return error_message
1188
+
1189
+
1190
+ def validate_feature_string(name: str, value: str) -> str:
1191
+ """Validate that a feature value is a string.
1192
+
1193
+ Args:
1194
+ name: Name of the feature being validated.
1195
+ value: Actual value to validate.
1196
+
1197
+ Returns:
1198
+ Error message string (empty if validation passes).
1199
+ """
1200
+ if not isinstance(value, str):
1201
+ return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
1202
+ return ""
1203
+
1204
+
1205
+ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
1206
+ """Validate that an episode buffer is properly formatted.
1207
+
1208
+ Checks that required keys exist, episode_index matches total_episodes,
1209
+ buffer is not empty, and all features are present.
1210
+
1211
+ Args:
1212
+ episode_buffer: Dictionary containing episode data to validate.
1213
+ total_episodes: Total number of episodes already in the dataset.
1214
+ features: Dictionary of expected feature specifications.
1215
+
1216
+ Raises:
1217
+ ValueError: If the buffer is missing required keys, is empty, or has
1218
+ mismatched features.
1219
+ NotImplementedError: If episode_index doesn't match total_episodes.
1220
+ """
1221
+ if "size" not in episode_buffer:
1222
+ raise ValueError("size key not found in episode_buffer")
1223
+
1224
+ if "task" not in episode_buffer:
1225
+ raise ValueError("task key not found in episode_buffer")
1226
+
1227
+ if episode_buffer["episode_index"] != total_episodes:
1228
+ # TODO(aliberts): Add option to use existing episode_index
1229
+ raise NotImplementedError(
1230
+ "You might have manually provided the episode_buffer with an episode_index that doesn't "
1231
+ "match the total number of episodes already in the dataset. This is not supported for now."
1232
+ )
1233
+
1234
+ if episode_buffer["size"] == 0:
1235
+ raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
1236
+
1237
+ buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
1238
+ if not buffer_keys == set(features):
1239
+ raise ValueError(
1240
+ f"Features from `episode_buffer` don't match the ones in `features`."
1241
+ f"In episode_buffer not in features: {buffer_keys - set(features)}"
1242
+ f"In features not in episode_buffer: {set(features) - buffer_keys}"
1243
+ )