robotdataset 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
1
+ """RLData - A Python package for robot learning dataset handling.
2
+
3
+ This package provides utilities for loading and handling robot learning datasets,
4
+ with support for the OXE (Open X-Embodiment) dataset collection from Google Cloud
5
+ and HuggingFace datasets.
6
+ """
7
+
8
+ from robotdataset.oxe_dataset import (
9
+ OXEDataset,
10
+ dataset2path,
11
+ list_datasets,
12
+ validate_dataset_name,
13
+ TemporalSampler,
14
+ )
15
+ from robotdataset.table30v2_dataset import Table30v2Dataset
16
+ from robotdataset.utils import batchViz, itemViz
17
+
18
# Public API of the package: the names imported above that are re-exported
# to callers (and picked up by ``from robotdataset import *``).
__all__ = [
    'OXEDataset',
    'Table30v2Dataset',
    'dataset2path',
    'list_datasets',
    'validate_dataset_name',
    'TemporalSampler',
    'batchViz',
    'itemViz',
]

# Package metadata.
__version__ = '0.1.0'
__author__ = 'Robotics Action Group'
@@ -0,0 +1,16 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+
8
def _get_cache_dir(override: Optional[str] = None) -> Path:
    """Resolve the root directory used for cached files.

    The first available source wins:

    1. ``override``, whenever it is not ``None``;
    2. the ``ROBOTDATASET_CACHE`` environment variable, when set and non-empty;
    3. ``~/.cache/robotdataset`` under the current user's home directory.

    Args:
        override: Explicit cache location; bypasses the environment lookup.

    Returns:
        The selected location as a ``Path`` (it is not created here).
    """
    if override is None:
        from_env = os.environ.get("ROBOTDATASET_CACHE")
        if from_env:
            return Path(from_env)
        return Path.home() / ".cache" / "robotdataset"
    return Path(override)
File without changes
@@ -0,0 +1,248 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from robotdataset.oxe.utils import infer_kind
10
+
11
+
12
# Column names used by different robotics HF dataset formats to identify episodes.
# ``_episode_id_column`` tries them in this order and returns the first match.
_EPISODE_ID_COLUMNS = ("episode_index", "episode_id", "traj_id", "trajectory_id")

# Column names used to identify tasks within a dataset.
# ``_task_id_column`` tries them in this order and returns the first match.
_TASK_ID_COLUMNS = ("task_index", "task_id", "task")

# Columns that are dataset bookkeeping metadata, not step observations.
# ``hf_episode_to_oxe_format`` removes these top-level keys from every step.
_META_COLUMNS = frozenset({
    "episode_index", "episode_id", "frame_index", "timestamp",
    "task_index", "task_id", "index", "traj_id", "trajectory_id",
})
23
+
24
+
25
def _require_datasets() -> Any:
    """Import and return the optional HuggingFace ``datasets`` module.

    The import happens at call time (function-scope import), so merely
    importing this module does not require ``datasets`` to be installed.

    Returns:
        The imported ``datasets`` module object.

    Raises:
        RuntimeError: If ``datasets`` cannot be imported. The underlying
            ``ImportError`` is attached as ``__cause__``.
    """
    try:
        import datasets
    except ImportError as exc:
        # Chain explicitly so the traceback shows the real import failure as
        # the direct cause instead of "during handling ... another exception".
        raise RuntimeError(
            "Table30v2Dataset requires the 'datasets' package. "
            "Install it with: pip install 'robotdataset[hf]'"
        ) from exc
    return datasets
34
+
35
+
36
def _episode_id_column(features: Any) -> str:
    """Pick the column that identifies episodes in a dataset schema.

    Args:
        features: Mapping-like schema supporting ``in`` and ``.keys()``.

    Returns:
        The first entry of ``_EPISODE_ID_COLUMNS`` that is present in
        ``features``.

    Raises:
        ValueError: If none of the candidate columns is present.
    """
    found = next((name for name in _EPISODE_ID_COLUMNS if name in features), None)
    if found is None:
        raise ValueError(
            f"No episode ID column found. Expected one of: {_EPISODE_ID_COLUMNS}. "
            f"Got: {list(features.keys())}"
        )
    return found
45
+
46
+
47
def _task_id_column(features: Any) -> Optional[str]:
    """Pick the column that identifies tasks in a dataset schema.

    Args:
        features: Mapping-like schema supporting ``in``.

    Returns:
        The first entry of ``_TASK_ID_COLUMNS`` present in ``features``,
        or ``None`` when the schema has no task column.
    """
    return next((name for name in _TASK_ID_COLUMNS if name in features), None)
53
+
54
+
55
class _FilteredDataset:
    """In-memory subset of an HF-style dataset, restricted to chosen tasks.

    All matching rows are copied into a list at construction time. The
    object then offers iteration, ``len()``, column access by name
    (``dataset["column"]`` returns that column as a list) and positional
    access (any other key is forwarded to the underlying row list). The
    source dataset's ``features`` is exposed unchanged.
    """

    def __init__(self, base: Any, keep_task_ids: frozenset) -> None:
        """Copy every row of ``base`` whose task column is in ``keep_task_ids``."""
        task_col = _task_id_column(base.features)
        kept: List[Dict[str, Any]] = []
        for record in base:
            if record[task_col] in keep_task_ids:
                kept.append(dict(record))
        self._rows: List[Dict[str, Any]] = kept
        self.features = base.features

    def __iter__(self):
        """Iterate over the retained rows (plain dicts)."""
        return iter(self._rows)

    def __getitem__(self, key: Any) -> Any:
        """Return a column (``str`` key) or index into the row list (other keys)."""
        if not isinstance(key, str):
            return self._rows[key]
        return [record[key] for record in self._rows]

    def __len__(self) -> int:
        """Number of retained rows."""
        return len(self._rows)
80
+
81
+
82
def filter_by_tasks(hf_dataset: Any, tasks: List[int]) -> "_FilteredDataset":
    """Build a task-restricted view of ``hf_dataset``.

    The result is a ``_FilteredDataset``: it can be iterated and indexed by
    column name like the input, and carries the same ``features``.

    Args:
        hf_dataset: A loaded HuggingFace Dataset (or compatible object).
        tasks: Task IDs to keep; rows whose task column holds any other
            value are left out.

    Returns:
        A ``_FilteredDataset`` holding only the matching rows.

    Raises:
        ValueError: If ``hf_dataset.features`` has no task ID column.
    """
    if _task_id_column(hf_dataset.features) is not None:
        return _FilteredDataset(hf_dataset, frozenset(tasks))
    raise ValueError(
        f"Task filtering requested but no task ID column found in this dataset. "
        f"Expected one of: {_TASK_ID_COLUMNS}. "
        f"Got: {list(hf_dataset.features.keys())}"
    )
105
+
106
+
107
def load_hf_dataset(dataset_name: str, split: str, cache_dir: Optional[Path] = None) -> Any:
    """Load one split of a HuggingFace dataset.

    Downloading and caching are delegated to ``datasets.load_dataset``.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_dataset``.
        split: Split name forwarded to ``load_dataset``.
        cache_dir: Optional cache root. When given, the ``hf_cache``
            sub-directory of it is passed to ``load_dataset`` as
            ``cache_dir``; when ``None`` the library default is used.

    Returns:
        Whatever ``datasets.load_dataset`` returns for the requested split.
    """
    hf = _require_datasets()
    if cache_dir is None:
        return hf.load_dataset(dataset_name, split=split)
    return hf.load_dataset(
        dataset_name,
        split=split,
        cache_dir=str(cache_dir / "hf_cache"),
    )
119
+
120
+
121
def get_episode_ids(hf_dataset: Any) -> List[int]:
    """List the distinct episode IDs of a dataset in ascending order.

    The episode column is located with ``_episode_id_column`` and read in
    one go via column access (``hf_dataset[column]``).
    """
    id_column = _episode_id_column(hf_dataset.features)
    unique_ids = list(set(hf_dataset[id_column]))
    unique_ids.sort()
    return unique_ids
125
+
126
+
127
def _convert_leaf(val: Any) -> Any:
    """Convert a single HF dataset leaf value to a torch.Tensor or str.

    Rules are tried in this order; the first match decides the result:

    * PIL image (only if Pillow is importable) -> uint8 tensor built from
      ``np.array(val, dtype=np.uint8)``.
    * ``torch.Tensor`` -> returned as is.
    * numpy array with string/bytes/object dtype -> ``val.tolist()`` when it
      has at least one dimension; a 0-d array becomes a ``str`` (bytes are
      UTF-8 decoded, falling back to ``str(item)`` if they are not valid
      UTF-8).
    * any other numpy array -> tensor via ``torch.from_numpy``.
    * numpy scalar -> ``str`` for string/bytes/object kinds (bytes are UTF-8
      decoded), otherwise a 0-d tensor.
    * non-empty list whose first element is an int/float -> float32 tensor;
      if tensor construction fails the list falls through unchanged.
    * Python bool/int/float -> 0-d tensor.
    * bytes/bytearray -> UTF-8 decoded ``str``, or the original object when
      decoding fails.
    * anything else -> returned unchanged.
    """
    # PIL Image -> uint8 numpy -> torch. Pillow is optional, so a missing
    # install simply disables this rule.
    try:
        from PIL.Image import Image as PILImage
    except ImportError:
        PILImage = None
    if PILImage is not None and isinstance(val, PILImage):
        return torch.from_numpy(np.ascontiguousarray(np.array(val, dtype=np.uint8)))

    if isinstance(val, torch.Tensor):
        return val

    if isinstance(val, np.ndarray):
        if val.dtype.kind in {"S", "U", "O"}:
            if val.ndim > 0:
                return val.tolist()
            # 0-d text array: unwrap to the Python object first so bytes are
            # decoded rather than rendered as their repr ("b'...'").
            item = val.item()
            if isinstance(item, bytes):
                try:
                    return item.decode("utf-8")
                except UnicodeDecodeError:
                    return str(item)
            return str(item)
        return torch.from_numpy(np.ascontiguousarray(val))

    if isinstance(val, np.generic):
        if val.dtype.kind in {"S", "U", "O"}:
            item = val.item()
            return item.decode("utf-8") if isinstance(item, bytes) else str(item)
        return torch.tensor(val.item())

    if isinstance(val, list) and val and isinstance(val[0], (int, float)):
        try:
            return torch.tensor(val, dtype=torch.float32)
        except Exception:
            # Deliberate best-effort: lists that cannot form a float tensor
            # (e.g. mixed content) are passed through by the rules below.
            pass

    # bool is a subclass of int; all three scalar types convert the same way
    # and torch.tensor keeps bool values as a bool tensor.
    if isinstance(val, (bool, int, float)):
        return torch.tensor(val)

    if isinstance(val, (bytes, bytearray)):
        try:
            return val.decode("utf-8")
        except UnicodeDecodeError:
            return val

    return val
170
+
171
+
172
def _to_nested(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Expand dot-separated keys into a tree of nested dicts.

    Each key is split on ``"."``; every segment but the last names an
    intermediate dict (created on first use) and the last segment receives
    the value.

    Example:
        ``{"observation.image": v1, "observation.state": v2}`` becomes
        ``{"observation": {"image": v1, "state": v2}}``.
    """
    tree: Dict[str, Any] = {}
    for dotted, value in flat.items():
        *parents, leaf = dotted.split(".")
        node = tree
        for name in parents:
            if name not in node:
                node[name] = {}
            node = node[name]
        node[leaf] = value
    return tree
186
+
187
+
188
def hf_episode_to_oxe_format(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Turn the rows of one HF episode into an OXE-style episode dict.

    The output has the shape ``{"steps": [step, ...]}`` with one step per
    input row, in input order, intended for ``episode_to_ted_steps()``.

    Per row:
    - every value is passed through ``_convert_leaf``;
    - dot-separated column names are expanded into nested dicts;
    - the ``next`` sub-dict is removed, and its ``reward`` / ``done`` /
      ``terminated`` entries are copied to top-level ``reward`` /
      ``is_last`` / ``is_terminal`` unless those keys already exist;
    - top-level keys listed in ``_META_COLUMNS`` are dropped.
    """
    # (key inside "next", OXE top-level name), in the order they are lifted.
    lifted = (
        ("reward", "reward"),
        ("done", "is_last"),
        ("terminated", "is_terminal"),
    )

    def _row_to_step(row: Dict[str, Any]) -> Dict[str, Any]:
        step = _to_nested({name: _convert_leaf(value) for name, value in row.items()})

        following = step.pop("next", {})
        for source_key, target_key in lifted:
            if source_key in following:
                step.setdefault(target_key, following[source_key])

        for meta_key in _META_COLUMNS:
            step.pop(meta_key, None)
        return step

    return {"steps": [_row_to_step(row) for row in rows]}
220
+
221
+
222
def infer_modalities_from_storage(td: Any) -> Dict[str, Dict[str, Any]]:
    """Describe every tensor leaf found in a (possibly nested) storage object.

    The container is walked depth-first. Anything exposing ``keys`` is
    treated as a nested container; ``torch.Tensor`` leaves are recorded and
    all other leaves are ignored.

    Args:
        td: Container offering ``keys()`` and ``get(key)`` (a loaded TED
            TensorDict in practice).

    Returns:
        ``{path: spec}`` where ``path`` joins the nested keys with ``"/"``
        and ``spec`` holds ``path``, ``kind`` (from ``infer_kind``),
        ``dtype`` (without the ``torch.`` prefix), ``shape`` (the tensor
        shape without its first dimension) and ``source`` (``"storage"``).
    """
    found: Dict[str, Dict[str, Any]] = {}

    def _walk(node: Any, parent: str) -> None:
        for name in node.keys():
            full_path = str(name) if not parent else f"{parent}/{name}"
            child = node.get(name)
            if hasattr(child, "keys"):
                _walk(child, full_path)
                continue
            if not isinstance(child, torch.Tensor):
                continue
            found[full_path] = {
                "path": full_path,
                "kind": infer_kind(full_path),
                "dtype": str(child.dtype).replace("torch.", ""),
                # Leading dimension is skipped -- presumably the step axis.
                "shape": tuple(int(dim) for dim in child.shape[1:]),
                "source": "storage",
            }

    _walk(td, "")
    return found
@@ -0,0 +1,100 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Set
6
+
7
+ import torch
8
+
9
+ from robotdataset.hf.loader import _episode_id_column, hf_episode_to_oxe_format
10
+ from robotdataset.oxe.memmap_builder import is_episode_cached
11
+ from robotdataset.oxe.utils import episode_to_ted_steps
12
+
13
+ try:
14
+ from tqdm.auto import tqdm as _tqdm_cls
15
+ except ImportError:
16
+ _tqdm_cls = None # type: ignore[assignment]
17
+
18
# File name of the JSON marker that ``_build_one_hf_episode`` writes into an
# episode directory after the memmap step; it records ``{"n_steps": ...}``.
_EPISODE_SENTINEL = "_steps.json"
19
+
20
+
21
def _build_one_hf_episode(
    episode_dict: Dict[str, Any],
    episode_id: int,
    episode_dir: Path,
) -> int:
    """Write one episode to disk as a memory-mapped stack of TED steps.

    The episode is converted with ``episode_to_ted_steps``; the resulting
    steps are stacked, memory-mapped into ``episode_dir`` (created if
    needed), and finally a small JSON marker with the step count is written
    next to them.

    Args:
        episode_dict: OXE-style episode (``{"steps": [...]}``).
        episode_id: Identifier forwarded to ``episode_to_ted_steps``.
        episode_dir: Destination directory for this episode.

    Returns:
        Number of steps written; ``0`` when the conversion produced no
        steps, in which case nothing is created on disk.
    """
    ted_steps = episode_to_ted_steps(episode_dict, episode_id, tf_tensor_types=())
    if not ted_steps:
        return 0

    episode_dir.mkdir(parents=True, exist_ok=True)
    stacked = torch.stack(ted_steps)
    stacked.memmap_(str(episode_dir))

    step_count = len(stacked)
    # The marker is written last, after the memmap call has returned.
    marker_path = episode_dir / _EPISODE_SENTINEL
    marker_path.write_text(json.dumps({"n_steps": step_count}))
    return step_count
39
+
40
+
41
def build_missing_episodes(
    hf_dataset: Any,
    episodes_dir: Path,
    missing: List[int],
) -> None:
    """Convert and cache only the episodes in ``missing`` (episode IDs).

    Episodes that turn out to be cached already are skipped (safe to call
    again after an interrupted run). For the remaining ones the HF dataset
    is streamed once to collect their rows, then each episode is converted
    and memory-mapped under ``episodes_dir / str(episode_id)`` in ascending
    ID order. Rows are ordered by ``frame_index`` when the dataset has that
    column. Requested IDs for which the dataset holds no rows are silently
    left uncached.

    A tqdm progress bar is shown when tqdm is importable; it is always
    closed, even if conversion raises.

    Args:
        hf_dataset: Loaded HuggingFace Dataset object.
        episodes_dir: Root directory for per-episode memmaps.
        missing: Episode IDs that are not yet cached.

    Raises:
        ValueError: If the dataset has no episode ID column.
    """
    if not missing:
        return

    missing_set: Set[int] = set(missing)
    col = _episode_id_column(hf_dataset.features)
    has_frame_index = "frame_index" in hf_dataset.features

    # Consult the cache *before* scanning so rows of already-cached episodes
    # are never materialised, and the scan is skipped when nothing is left.
    pending: Set[int] = {
        eid for eid in missing_set
        if not is_episode_cached(episodes_dir / str(eid))
    }

    pbar = (
        _tqdm_cls(
            total=len(missing_set),
            desc="Converting episodes to TED",
            unit="ep",
            dynamic_ncols=True,
        )
        if _tqdm_cls is not None
        else None
    )

    try:
        already_cached = len(missing_set) - len(pending)
        if pbar is not None and already_cached:
            pbar.update(already_cached)
        if not pending:
            return

        # Single pass over the dataset to collect rows for pending episodes.
        episode_rows: Dict[int, List[Any]] = {eid: [] for eid in pending}
        for row in hf_dataset:
            eid = row[col]
            if eid in pending:
                episode_rows[eid].append(dict(row))

        for episode_id in sorted(pending):
            # pop() releases each episode's rows once it has been converted.
            rows = episode_rows.pop(episode_id)
            if rows:
                if has_frame_index:
                    rows.sort(key=lambda r: r["frame_index"])
                episode_dict = hf_episode_to_oxe_format(rows)
                _build_one_hf_episode(
                    episode_dict, episode_id, episodes_dir / str(episode_id)
                )
            if pbar is not None:
                pbar.update(1)
    finally:
        if pbar is not None:
            pbar.close()
@@ -0,0 +1,26 @@
1
+ """Internal helpers for OXE dataset loading."""
2
+
3
+ from robotdataset.oxe.bucket import discover_dataset_versions, discover_datasets_from_bucket
4
+ from robotdataset.oxe.temporal_sampler import TemporalSampler
5
+ from robotdataset.oxe.utils import (
6
+ ModalitySpec,
7
+ flatten_structure,
8
+ infer_kind,
9
+ latest_version,
10
+ normalize_version_key,
11
+ shape_and_dtype,
12
+ tf_to_torch,
13
+ )
14
+
15
# Names re-exported from the ``robotdataset.oxe`` sub-modules imported above.
__all__ = [
    "ModalitySpec",
    "normalize_version_key",
    "latest_version",
    "tf_to_torch",
    "infer_kind",
    "shape_and_dtype",
    "flatten_structure",
    "discover_dataset_versions",
    "discover_datasets_from_bucket",
    "TemporalSampler",
]
@@ -0,0 +1,113 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from concurrent.futures import ThreadPoolExecutor, as_completed
5
+ from typing import Dict, List, Optional
6
+
7
+ from robotdataset.oxe.utils import latest_version
8
+
9
# Parallelism for the full-bucket scan (I/O bound, so threads are fine).
# Passed as ``max_workers`` to every ThreadPoolExecutor created in this module.
_SCAN_WORKERS = 32
11
+
12
+
13
def _safe_gfile_listdir(tf_module: Optional[object], path: str) -> List[str]:
    """List ``path`` through ``tf_module.io.gfile`` without ever raising.

    Args:
        tf_module: Module exposing ``io.gfile.listdir`` (TensorFlow), or
            ``None`` when it is unavailable.
        path: Location to list.

    Returns:
        The listing as returned by ``listdir``. An empty list when
        ``tf_module`` is ``None`` or when listing fails; a failure also
        emits a warning describing the error.
    """
    if tf_module is not None:
        try:
            return tf_module.io.gfile.listdir(path)
        except Exception as exc:
            # stacklevel=3 attributes the warning two frames above this helper.
            warnings.warn(
                f"Could not list GCS path '{path}': {exc}. "
                "Dataset discovery will return no results.",
                stacklevel=3,
            )
    return []
25
+
26
+
27
def _safe_gfile_isdir(tf_module: Optional[object], path: str) -> bool:
    """Report whether ``path`` is a directory according to ``tf_module.io.gfile``.

    Never raises: a missing module (``None``) or any error raised by the
    probe is reported as ``False``, without a warning.
    """
    is_directory = False
    if tf_module is not None:
        try:
            is_directory = tf_module.io.gfile.isdir(path)
        except Exception:
            # Any failure while probing counts as "not a directory".
            is_directory = False
    return is_directory
34
+
35
+
36
def _join(base: str, name: str) -> str:
    """Glue ``base`` and ``name`` together with exactly one ``/`` between them.

    Trailing slashes of ``base`` and both leading and trailing slashes of
    ``name`` are removed first.
    """
    head = base.rstrip("/")
    tail = name.strip("/")
    return "/".join((head, tail))
39
+
40
+
41
def _discover_one_dataset(
    tf_module: object, bucket_url: str, raw_name: str
) -> Optional[tuple[str, Dict[str, str]]]:
    """Inspect one top-level bucket entry and report its version directories.

    Args:
        tf_module: Module exposing ``io.gfile`` (TensorFlow).
        bucket_url: Root URL of the bucket.
        raw_name: Entry name as listed in the bucket; surrounding slashes
            are stripped to obtain the dataset name.

    Returns:
        ``None`` when the entry is not a directory. Otherwise
        ``(dataset_name, {version: path})`` with one item per sub-directory,
        or ``(dataset_name, {"": dataset_root})`` when the dataset directory
        contains no sub-directories.
    """
    dataset_name = raw_name.strip("/")
    dataset_root = _join(bucket_url, dataset_name)
    if not _safe_gfile_isdir(tf_module, dataset_root):
        return None

    entries = _safe_gfile_listdir(tf_module, dataset_root)
    entry_paths = [_join(dataset_root, entry) for entry in entries]

    # Probe all children concurrently; map() keeps results aligned with entries.
    with ThreadPoolExecutor(max_workers=_SCAN_WORKERS) as pool:
        dir_flags = list(
            pool.map(lambda path: _safe_gfile_isdir(tf_module, path), entry_paths)
        )

    version_paths: Dict[str, str] = {}
    for entry, path, is_dir in zip(entries, entry_paths, dir_flags):
        if is_dir:
            version_paths[entry.strip("/")] = path

    if not version_paths:
        return dataset_name, {"": dataset_root}
    return dataset_name, version_paths
63
+
64
+
65
def discover_dataset_versions(
    tf_module: Optional[object], bucket_url: str, dataset_name: str
) -> Dict[str, str]:
    """Map each version of one dataset to its path, looking the dataset up directly.

    Args:
        tf_module: Module exposing ``io.gfile`` (TensorFlow), or ``None``.
        bucket_url: Root URL of the bucket.
        dataset_name: Name of the dataset directory inside the bucket.

    Returns:
        ``{version: path}`` with one item per sub-directory of the dataset
        root. ``{"": dataset_root}`` when the root is a directory without
        sub-directories. ``{}`` when ``tf_module`` is ``None`` or the root
        is not a directory.
    """
    if tf_module is None:
        return {}

    dataset_root = _join(bucket_url, dataset_name)
    entries = _safe_gfile_listdir(tf_module, dataset_root)

    def _probe(entry: str) -> bool:
        return _safe_gfile_isdir(tf_module, _join(dataset_root, entry))

    # Children are probed in parallel; map() preserves the order of entries.
    with ThreadPoolExecutor(max_workers=_SCAN_WORKERS) as pool:
        dir_flags = list(pool.map(_probe, entries))

    by_version: Dict[str, str] = {}
    for entry, is_dir in zip(entries, dir_flags):
        if is_dir:
            version = entry.strip("/")
            by_version[version] = _join(dataset_root, version)

    if by_version:
        return by_version
    return {"": dataset_root} if _safe_gfile_isdir(tf_module, dataset_root) else {}
86
+
87
+
88
def discover_datasets_from_bucket(
    tf_module: Optional[object], bucket_url: str
) -> Dict[str, Dict[str, str]]:
    """Return {dataset_name: {version: gcs_path}} for every dataset in the bucket.

    Each top-level bucket entry is inspected by ``_discover_one_dataset`` on
    a thread pool, so the per-dataset isdir/listdir calls overlap instead of
    running one after another. Entries that are not directories are left
    out.

    The result follows the order of the bucket listing: results are gathered
    with ``Executor.map`` rather than in completion order, so the key order
    (and which entry wins if two listed names normalise to the same dataset
    name) does not depend on thread timing.

    Args:
        tf_module: Module exposing ``io.gfile`` (TensorFlow), or ``None``.
        bucket_url: Root URL of the bucket to scan.

    Returns:
        Mapping of dataset name to its ``{version: path}`` mapping; empty
        when ``tf_module`` is ``None`` or the bucket listing is empty.
    """
    if tf_module is None:
        return {}

    raw_names = _safe_gfile_listdir(tf_module, bucket_url)
    if not raw_names:
        # Nothing to scan; avoid spinning up a thread pool.
        return {}

    def _scan(name: str) -> Optional[tuple[str, Dict[str, str]]]:
        return _discover_one_dataset(tf_module, bucket_url, name)

    with ThreadPoolExecutor(max_workers=_SCAN_WORKERS) as pool:
        results = list(pool.map(_scan, raw_names))

    discovered: Dict[str, Dict[str, str]] = {}
    for result in results:
        if result is not None:
            name, versions = result
            discovered[name] = versions
    return discovered