robotdataset 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- robotdataset/__init__.py +30 -0
- robotdataset/_common.py +16 -0
- robotdataset/hf/__init__.py +0 -0
- robotdataset/hf/loader.py +248 -0
- robotdataset/hf/memmap_builder.py +100 -0
- robotdataset/oxe/__init__.py +26 -0
- robotdataset/oxe/bucket.py +113 -0
- robotdataset/oxe/memmap_builder.py +173 -0
- robotdataset/oxe/temporal_sampler.py +187 -0
- robotdataset/oxe/utils.py +243 -0
- robotdataset/oxe_dataset.py +556 -0
- robotdataset/table30v2_dataset.py +237 -0
- robotdataset/utils/__init__.py +3 -0
- robotdataset/utils/visualization.py +199 -0
- robotdataset-0.1.0.dist-info/METADATA +75 -0
- robotdataset-0.1.0.dist-info/RECORD +19 -0
- robotdataset-0.1.0.dist-info/WHEEL +5 -0
- robotdataset-0.1.0.dist-info/licenses/LICENSE +339 -0
- robotdataset-0.1.0.dist-info/top_level.txt +1 -0
robotdataset/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""RLData - A Python package for robot learning dataset handling.
|
|
2
|
+
|
|
3
|
+
This package provides utilities for loading and handling robot learning datasets,
|
|
4
|
+
with support for the OXE (Open X-Embodiment) dataset collection from Google Cloud
|
|
5
|
+
and HuggingFace datasets.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from robotdataset.oxe_dataset import (
|
|
9
|
+
OXEDataset,
|
|
10
|
+
dataset2path,
|
|
11
|
+
list_datasets,
|
|
12
|
+
validate_dataset_name,
|
|
13
|
+
TemporalSampler,
|
|
14
|
+
)
|
|
15
|
+
from robotdataset.table30v2_dataset import Table30v2Dataset
|
|
16
|
+
from robotdataset.utils import batchViz, itemViz
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
'OXEDataset',
|
|
20
|
+
'Table30v2Dataset',
|
|
21
|
+
'dataset2path',
|
|
22
|
+
'list_datasets',
|
|
23
|
+
'validate_dataset_name',
|
|
24
|
+
'TemporalSampler',
|
|
25
|
+
'batchViz',
|
|
26
|
+
'itemViz',
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
__version__ = '0.1.0'
|
|
30
|
+
__author__ = 'Robotics Action Group'
|
robotdataset/_common.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _get_cache_dir(override: Optional[str] = None) -> Path:
    """Resolve the directory under which robotdataset keeps cached data.

    The first available source wins:

    1. ``override``, whenever it is not ``None`` (an empty string counts).
    2. The ``ROBOTDATASET_CACHE`` environment variable, if set and non-empty.
    3. ``~/.cache/robotdataset`` under the current user's home directory.

    Args:
        override: Explicit cache root supplied by the caller.

    Returns:
        The selected location as a ``Path``. Nothing is created on disk here.
    """
    if override is None:
        configured = os.environ.get("ROBOTDATASET_CACHE")
        if configured:
            return Path(configured)
        return Path.home().joinpath(".cache", "robotdataset")
    return Path(override)
|
|
File without changes
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from robotdataset.oxe.utils import infer_kind
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Column names used by different robotics HF dataset formats to identify episodes
|
|
13
|
+
_EPISODE_ID_COLUMNS = ("episode_index", "episode_id", "traj_id", "trajectory_id")
|
|
14
|
+
|
|
15
|
+
# Column names used to identify tasks within a dataset
|
|
16
|
+
_TASK_ID_COLUMNS = ("task_index", "task_id", "task")
|
|
17
|
+
|
|
18
|
+
# Columns that are dataset bookkeeping metadata, not step observations
|
|
19
|
+
_META_COLUMNS = frozenset({
|
|
20
|
+
"episode_index", "episode_id", "frame_index", "timestamp",
|
|
21
|
+
"task_index", "task_id", "index", "traj_id", "trajectory_id",
|
|
22
|
+
})
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _require_datasets() -> Any:
    """Import and return the optional HuggingFace ``datasets`` module.

    The import happens inside the function so that merely importing this
    module does not require ``datasets`` to be installed.

    Returns:
        The imported ``datasets`` module object.

    Raises:
        RuntimeError: If ``datasets`` cannot be imported. The underlying
            ``ImportError`` is attached as ``__cause__`` so the real reason
            (missing package, broken transitive dependency, ...) stays
            visible in the traceback.
    """
    try:
        import datasets
    except ImportError as exc:
        raise RuntimeError(
            "Table30v2Dataset requires the 'datasets' package. "
            "Install it with: pip install 'robotdataset[hf]'"
        ) from exc
    # Returned outside the ``try`` so only the import itself is guarded.
    return datasets
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _episode_id_column(features: Any) -> str:
    """Pick the column that identifies episodes in a dataset.

    Candidates from ``_EPISODE_ID_COLUMNS`` are tried in their declared order
    and the first one contained in ``features`` is returned.

    Args:
        features: Mapping-like schema supporting ``in`` and ``.keys()``.

    Returns:
        The name of the first matching episode ID column.

    Raises:
        ValueError: If none of the candidate columns is present.
    """
    match = next((name for name in _EPISODE_ID_COLUMNS if name in features), None)
    if match is None:
        raise ValueError(
            f"No episode ID column found. Expected one of: {_EPISODE_ID_COLUMNS}. "
            f"Got: {list(features.keys())}"
        )
    return match
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _task_id_column(features: Any) -> Optional[str]:
    """Pick the column that carries task IDs, if the dataset has one.

    Candidates from ``_TASK_ID_COLUMNS`` are tried in their declared order.

    Args:
        features: Mapping-like schema supporting ``in``.

    Returns:
        The first candidate found in ``features``, or ``None`` when no
        candidate is present.
    """
    return next((name for name in _TASK_ID_COLUMNS if name in features), None)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _FilteredDataset:
    """In-memory subset of a dataset restricted to selected task IDs.

    On construction the base dataset is iterated once and every row whose
    task column value is in ``keep_task_ids`` is copied into a plain ``dict``.
    The view then offers iteration, ``len()``, positional indexing and
    whole-column access via ``view["column_name"]``. The base dataset's
    ``features`` object is exposed unchanged.
    """

    def __init__(self, base: Any, keep_task_ids: frozenset) -> None:
        # The task column is looked up from the base schema; callers are
        # expected to have verified that one exists.
        task_col = _task_id_column(base.features)
        kept: List[Dict[str, Any]] = []
        for record in base:
            if record[task_col] in keep_task_ids:
                kept.append(dict(record))
        self._rows = kept
        self.features = base.features

    def __iter__(self):
        return iter(self._rows)

    def __getitem__(self, key: Any) -> Any:
        # A string key selects a column across all kept rows; anything else
        # (int, slice) is forwarded to the underlying list of rows.
        if not isinstance(key, str):
            return self._rows[key]
        return [record[key] for record in self._rows]

    def __len__(self) -> int:
        return len(self._rows)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def filter_by_tasks(hf_dataset: Any, tasks: List[int]) -> "_FilteredDataset":
    """Build a task-restricted view of ``hf_dataset``.

    The dataset schema is checked for a task ID column first; only then is a
    ``_FilteredDataset`` constructed that keeps the rows whose task value is
    one of ``tasks``.

    Args:
        hf_dataset: Dataset object exposing ``features`` and row iteration.
        tasks: Task IDs to retain; rows with any other task value are dropped.

    Returns:
        A ``_FilteredDataset`` over the retained rows.

    Raises:
        ValueError: If ``hf_dataset.features`` contains no task ID column.
    """
    if _task_id_column(hf_dataset.features) is None:
        available = list(hf_dataset.features.keys())
        raise ValueError(
            "Task filtering requested but no task ID column found in this dataset. "
            f"Expected one of: {_TASK_ID_COLUMNS}. "
            f"Got: {available}"
        )
    keep = frozenset(tasks)
    return _FilteredDataset(hf_dataset, keep)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def load_hf_dataset(dataset_name: str, split: str, cache_dir: Optional[Path] = None) -> Any:
    """Load one split of a HuggingFace dataset.

    Download and caching are delegated to ``datasets.load_dataset``. When
    ``cache_dir`` is given, its ``hf_cache`` sub-directory is passed to
    ``load_dataset`` as the cache location; otherwise the library default
    applies.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_dataset``.
        split: Split name forwarded to ``load_dataset``.
        cache_dir: Optional cache root. A ``Path`` is expected, but a plain
            string (or any ``os.PathLike``) is accepted as well.

    Returns:
        Whatever ``datasets.load_dataset`` returns for the requested split.

    Raises:
        RuntimeError: If the ``datasets`` package is not installed.
    """
    datasets = _require_datasets()
    kwargs: Dict[str, Any] = {}
    if cache_dir is not None:
        # Coerce to Path so a string cache root does not fail on the ``/``
        # operator with a TypeError.
        kwargs["cache_dir"] = str(Path(cache_dir) / "hf_cache")
    return datasets.load_dataset(dataset_name, split=split, **kwargs)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def get_episode_ids(hf_dataset: Any) -> List[int]:
    """List the distinct episode IDs of ``hf_dataset`` in ascending order.

    The episode column is located from ``hf_dataset.features``, read in full
    via ``hf_dataset[column]``, de-duplicated and sorted.

    Raises:
        ValueError: If the dataset has no recognised episode ID column.
    """
    id_column = _episode_id_column(hf_dataset.features)
    distinct = set(hf_dataset[id_column])
    return sorted(distinct)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _convert_leaf(val: Any) -> Any:
    """Convert a single HF dataset leaf value to a torch.Tensor or str.

    Rules, checked in this order:

    - PIL image → uint8 tensor built from ``np.array(val)`` (only attempted
      when Pillow can be imported).
    - ``torch.Tensor`` → returned unchanged.
    - numpy array with bytes/str/object dtype → ``val.tolist()`` when it has
      at least one dimension; a 0-d array yields its single element as text
      (bytes are decoded as UTF-8).
    - any other numpy array → tensor over a contiguous version of the array.
    - numpy scalar → text for bytes/str/object kinds, otherwise a tensor.
    - non-empty list whose first element is an int or float → float32 tensor;
      if that conversion fails the value falls through to the rules below.
    - ``bool`` / ``int`` / ``float`` → tensor.
    - ``bytes`` / ``bytearray`` → UTF-8 decoded text, or the raw value when it
      is not valid UTF-8.
    - anything else → returned unchanged.

    Args:
        val: The raw leaf value taken from a dataset row.

    Returns:
        A tensor, a string, or ``val`` itself when no rule applies.
    """
    # PIL Image → uint8 numpy → torch
    try:
        from PIL.Image import Image as PILImage
        if isinstance(val, PILImage):
            return torch.from_numpy(np.ascontiguousarray(np.array(val, dtype=np.uint8)))
    except ImportError:
        # Pillow is optional; without it no value can be a PIL image.
        pass

    if isinstance(val, torch.Tensor):
        return val

    if isinstance(val, np.ndarray):
        if val.dtype.kind in {"S", "U", "O"}:
            if val.ndim > 0:
                return val.tolist()
            # 0-d text array: unwrap to the Python object first. Calling
            # str() on a bytes scalar would produce the repr "b'...'"
            # instead of the text, unlike the numpy-scalar branch below.
            item = val.item()
            if isinstance(item, bytes):
                try:
                    return item.decode("utf-8")
                except UnicodeDecodeError:
                    return str(item)
            return str(item)
        return torch.from_numpy(np.ascontiguousarray(val))

    if isinstance(val, np.generic):
        if val.dtype.kind in {"S", "U", "O"}:
            item = val.item()
            return item.decode("utf-8") if isinstance(item, bytes) else str(item)
        return torch.tensor(val.item())

    if isinstance(val, list) and val and isinstance(val[0], (int, float)):
        try:
            return torch.tensor(val, dtype=torch.float32)
        except Exception:
            # Deliberate best effort: a ragged or mixed-type list is left
            # for the generic fall-through at the end of the function.
            pass

    # bool must come before int since bool is a subclass of int
    if isinstance(val, bool):
        return torch.tensor(val)
    if isinstance(val, (int, float)):
        return torch.tensor(val)

    if isinstance(val, (bytes, bytearray)):
        try:
            return val.decode("utf-8")
        except UnicodeDecodeError:
            # Not text after all; hand the raw bytes back untouched.
            return val

    return val
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _to_nested(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Expand dot-separated keys into a tree of nested dicts.

    Each key is split on ``"."``; every segment but the last names a nested
    dict (created on first use) and the last segment receives the value.

    E.g. {"observation.image": v1, "observation.state": v2}
      → {"observation": {"image": v1, "state": v2}}

    Args:
        flat: Mapping from dotted key to value.

    Returns:
        A new nested dict; ``flat`` itself is not modified.
    """
    tree: Dict[str, Any] = {}
    for dotted, leaf in flat.items():
        *parents, last = dotted.split(".")
        cursor = tree
        for name in parents:
            cursor = cursor.setdefault(name, {})
        cursor[last] = leaf
    return tree
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def hf_episode_to_oxe_format(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Turn the rows of one episode into an OXE-style episode dict.

    Each row becomes one step:

    1. Every value is passed through ``_convert_leaf``.
    2. Dot-separated column names are expanded into nested dicts.
    3. A ``next`` sub-dict, if present, is removed and its ``reward`` /
       ``done`` / ``terminated`` entries are copied to the top-level keys
       ``reward`` / ``is_last`` / ``is_terminal`` — only when the step does
       not already define those keys.
    4. Top-level keys listed in ``_META_COLUMNS`` are dropped.

    Args:
        rows: The rows belonging to a single episode, in step order.

    Returns:
        ``{"steps": [step_dict, ...]}`` with one entry per input row.
    """
    # (key inside "next", top-level key it is lifted to)
    lifted_fields = (
        ("reward", "reward"),
        ("done", "is_last"),
        ("terminated", "is_terminal"),
    )

    steps = []
    for row in rows:
        step = _to_nested({name: _convert_leaf(value) for name, value in row.items()})

        # Lift LeRobot-style "next" sub-dict to OXE top-level field names
        next_dict = step.pop("next", {})
        for source_key, target_key in lifted_fields:
            if source_key in next_dict:
                step.setdefault(target_key, next_dict[source_key])

        # Remove bookkeeping metadata columns
        for meta_name in _META_COLUMNS:
            step.pop(meta_name, None)

        steps.append(step)
    return {"steps": steps}
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def infer_modalities_from_storage(td: Any) -> Dict[str, Dict[str, Any]]:
    """Describe every tensor leaf found in ``td``.

    ``td`` is walked recursively: any value exposing ``keys`` is treated as a
    nested container, and every ``torch.Tensor`` leaf yields one spec dict.
    Paths are built by joining keys with ``/``. The recorded shape omits the
    tensor's first dimension.

    Args:
        td: Container exposing ``keys()`` and ``get(key)``.

    Returns:
        ``{path: spec}`` where each spec holds ``path``, ``kind`` (from
        ``infer_kind``), ``dtype`` (without the ``torch.`` prefix), ``shape``
        and ``source`` (always ``"storage"``).
    """
    specs: Dict[str, Dict[str, Any]] = {}

    def _walk(node: Any, parent_path: str = "") -> None:
        for name in node.keys():
            full_path = str(name) if not parent_path else f"{parent_path}/{name}"
            child = node.get(name)
            if hasattr(child, "keys"):
                _walk(child, full_path)
                continue
            if not isinstance(child, torch.Tensor):
                # Non-tensor leaves are not reported.
                continue
            dtype_name = str(child.dtype).replace("torch.", "")
            trailing_shape = tuple(int(dim) for dim in child.shape[1:])
            specs[full_path] = {
                "path": full_path,
                "kind": infer_kind(full_path),
                "dtype": dtype_name,
                "shape": trailing_shape,
                "source": "storage",
            }

    _walk(td)
    return specs
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, List, Set
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from robotdataset.hf.loader import _episode_id_column, hf_episode_to_oxe_format
|
|
10
|
+
from robotdataset.oxe.memmap_builder import is_episode_cached
|
|
11
|
+
from robotdataset.oxe.utils import episode_to_ted_steps
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from tqdm.auto import tqdm as _tqdm_cls
|
|
15
|
+
except ImportError:
|
|
16
|
+
_tqdm_cls = None # type: ignore[assignment]
|
|
17
|
+
|
|
18
|
+
_EPISODE_SENTINEL = "_steps.json"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _build_one_hf_episode(
    episode_dict: Dict[str, Any],
    episode_id: int,
    episode_dir: Path,
) -> int:
    """Convert one episode to TED steps and persist it under ``episode_dir``.

    The steps produced by ``episode_to_ted_steps`` are stacked, memory-mapped
    into ``episode_dir`` via ``memmap_``, and a ``_steps.json`` file recording
    ``n_steps`` is written last. Nothing is created on disk when the episode
    yields no steps.

    Args:
        episode_dict: OXE-style episode dict.
        episode_id: Identifier forwarded to ``episode_to_ted_steps``.
        episode_dir: Target directory; created (with parents) if missing.

    Returns:
        The number of steps written, or 0 for an empty episode.
    """
    ted_steps = episode_to_ted_steps(episode_dict, episode_id, tf_tensor_types=())
    if not ted_steps:
        return 0

    episode_dir.mkdir(parents=True, exist_ok=True)
    stacked = torch.stack(ted_steps)
    stacked.memmap_(str(episode_dir))

    step_count = len(stacked)
    sentinel_path = episode_dir / _EPISODE_SENTINEL
    sentinel_path.write_text(json.dumps({"n_steps": step_count}))
    return step_count
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def build_missing_episodes(
    hf_dataset: Any,
    episodes_dir: Path,
    missing: List[int],
) -> None:
    """Convert and cache only the episodes in ``missing`` (episode IDs).

    Episodes that are already cached are filtered out first, so their rows
    are never copied into memory and a fully cached request performs no
    dataset pass at all. The dataset is then streamed once to collect the
    rows of the remaining episodes, which are converted and memory-mapped
    one by one. Safe to call again after an interrupted run.

    Episode IDs in ``missing`` that have no rows in the dataset are skipped
    without error.

    Args:
        hf_dataset: Loaded HuggingFace Dataset object (or compatible view).
        episodes_dir: Root directory for per-episode memmaps; each episode
            is stored in a sub-directory named after its ID.
        missing: Episode IDs that are expected not to be cached yet.

    Raises:
        ValueError: If work remains and the dataset has no episode ID column.
    """
    if not missing:
        return

    # Re-check the cache before touching the dataset: anything finished by an
    # earlier run needs neither its rows loaded nor a conversion.
    pending: Set[int] = {
        eid for eid in set(missing)
        if not is_episode_cached(episodes_dir / str(eid))
    }
    if not pending:
        return

    col = _episode_id_column(hf_dataset.features)
    has_frame_index = "frame_index" in hf_dataset.features

    # Single pass over the dataset to collect rows for all pending episodes
    episode_rows: Dict[int, List[Any]] = {eid: [] for eid in pending}
    for row in hf_dataset:
        eid = row[col]
        if eid in pending:
            episode_rows[eid].append(dict(row))

    if has_frame_index:
        # Rows are not assumed to arrive in step order.
        for rows in episode_rows.values():
            rows.sort(key=lambda r: r["frame_index"])

    pbar = (
        _tqdm_cls(
            total=len(pending),
            desc="Converting episodes to TED",
            unit="ep",
            dynamic_ncols=True,
        )
        if _tqdm_cls is not None
        else None
    )

    try:
        for episode_id in sorted(pending):
            episode_dir = episodes_dir / str(episode_id)
            # pop() releases each episode's rows as soon as it is written.
            rows = episode_rows.pop(episode_id)
            # The cache is checked again right before building in case the
            # episode appeared while the dataset was being streamed.
            if rows and not is_episode_cached(episode_dir):
                episode_dict = hf_episode_to_oxe_format(rows)
                _build_one_hf_episode(episode_dict, episode_id, episode_dir)
            if pbar is not None:
                pbar.update(1)
    finally:
        # Always release the progress bar, even when a conversion fails.
        if pbar is not None:
            pbar.close()
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Internal helpers for OXE dataset loading."""
|
|
2
|
+
|
|
3
|
+
from robotdataset.oxe.bucket import discover_dataset_versions, discover_datasets_from_bucket
|
|
4
|
+
from robotdataset.oxe.temporal_sampler import TemporalSampler
|
|
5
|
+
from robotdataset.oxe.utils import (
|
|
6
|
+
ModalitySpec,
|
|
7
|
+
flatten_structure,
|
|
8
|
+
infer_kind,
|
|
9
|
+
latest_version,
|
|
10
|
+
normalize_version_key,
|
|
11
|
+
shape_and_dtype,
|
|
12
|
+
tf_to_torch,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"ModalitySpec",
|
|
17
|
+
"normalize_version_key",
|
|
18
|
+
"latest_version",
|
|
19
|
+
"tf_to_torch",
|
|
20
|
+
"infer_kind",
|
|
21
|
+
"shape_and_dtype",
|
|
22
|
+
"flatten_structure",
|
|
23
|
+
"discover_dataset_versions",
|
|
24
|
+
"discover_datasets_from_bucket",
|
|
25
|
+
"TemporalSampler",
|
|
26
|
+
]
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
5
|
+
from typing import Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from robotdataset.oxe.utils import latest_version
|
|
8
|
+
|
|
9
|
+
# Parallelism for the full-bucket scan (I/O bound, so threads are fine).
|
|
10
|
+
_SCAN_WORKERS = 32
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _safe_gfile_listdir(tf_module: Optional[object], path: str) -> List[str]:
    """List ``path`` through ``tf_module.io.gfile.listdir`` without raising.

    Args:
        tf_module: Module object providing ``io.gfile.listdir``, or ``None``
            when it is unavailable.
        path: Location to list.

    Returns:
        The entries reported by ``listdir``. An empty list is returned when
        ``tf_module`` is ``None`` or when the call raises; in the latter case
        a warning describing the failure is emitted.
    """
    listing: List[str] = []
    if tf_module is not None:
        try:
            listing = tf_module.io.gfile.listdir(path)
        except Exception as exc:
            warnings.warn(
                f"Could not list GCS path '{path}': {exc}. "
                "Dataset discovery will return no results.",
                stacklevel=3,
            )
    return listing
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _safe_gfile_isdir(tf_module: Optional[object], path: str) -> bool:
    """Ask ``tf_module.io.gfile.isdir`` whether ``path`` is a directory.

    Never raises: a missing module (``None``) or any exception from the
    underlying call is reported as ``False``.
    """
    if tf_module is None:
        return False
    try:
        answer = tf_module.io.gfile.isdir(path)
    except Exception:
        # Any lookup failure is treated as "not a directory".
        answer = False
    return answer
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _join(base: str, name: str) -> str:
    """Combine a base URL and an entry name with exactly one ``/`` between.

    Trailing slashes are removed from ``base`` and slashes on both ends are
    removed from ``name`` before joining.
    """
    head = base.rstrip("/")
    tail = name.strip("/")
    return "/".join((head, tail))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _discover_one_dataset(
    tf_module: object, bucket_url: str, raw_name: str
) -> Optional[tuple[str, Dict[str, str]]]:
    """Inspect one bucket entry and report its version directories.

    Args:
        tf_module: Module object providing ``io.gfile``.
        bucket_url: Root URL of the bucket.
        raw_name: Entry name as listed in the bucket (slashes are stripped).

    Returns:
        ``None`` when the entry is not a directory. Otherwise a tuple
        ``(dataset_name, {version: path})``; when the dataset directory
        contains no sub-directories the mapping is ``{"": dataset_root}``.
    """
    dataset_name = raw_name.strip("/")
    dataset_root = _join(bucket_url, dataset_name)
    if not _safe_gfile_isdir(tf_module, dataset_root):
        return None

    children = _safe_gfile_listdir(tf_module, dataset_root)

    # Probe all children concurrently; every probe is an independent lookup.
    with ThreadPoolExecutor(max_workers=_SCAN_WORKERS) as pool:
        dir_flags = list(
            pool.map(
                lambda child: _safe_gfile_isdir(tf_module, _join(dataset_root, child)),
                children,
            )
        )

    version_paths: Dict[str, str] = {}
    for child, is_dir in zip(children, dir_flags):
        if is_dir:
            version = child.strip("/")
            version_paths[version] = _join(dataset_root, version)

    if not version_paths:
        # No sub-directories: map the empty version key to the root itself.
        version_paths[""] = dataset_root
    return dataset_name, version_paths
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def discover_dataset_versions(
    tf_module: Optional[object], bucket_url: str, dataset_name: str
) -> Dict[str, str]:
    """Map each version directory of one dataset to its full path.

    The dataset root is listed directly (no bucket-wide scan) and every entry
    is probed concurrently to see whether it is a directory.

    Args:
        tf_module: Module object providing ``io.gfile``, or ``None``.
        bucket_url: Root URL of the bucket.
        dataset_name: Name of the dataset directory inside the bucket.

    Returns:
        ``{version: path}`` for every sub-directory found. If there are none
        but the dataset root itself is a directory, ``{"": dataset_root}``.
        An empty dict when ``tf_module`` is ``None`` or nothing was found.
    """
    if tf_module is None:
        return {}

    root = _join(bucket_url, dataset_name)
    children = _safe_gfile_listdir(tf_module, root)
    candidate_paths = [_join(root, child) for child in children]

    with ThreadPoolExecutor(max_workers=_SCAN_WORKERS) as pool:
        dir_flags = list(
            pool.map(lambda path: _safe_gfile_isdir(tf_module, path), candidate_paths)
        )

    # ``_join`` already strips slashes from the entry name, so the probed path
    # is also the path stored for the version.
    found = {
        child.strip("/"): path
        for child, path, is_dir in zip(children, candidate_paths, dir_flags)
        if is_dir
    }
    if found:
        return found
    return {"": root} if _safe_gfile_isdir(tf_module, root) else {}
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def discover_datasets_from_bucket(
    tf_module: Optional[object], bucket_url: str
) -> Dict[str, Dict[str, str]]:
    """Scan the bucket and collect the versions of every dataset in it.

    The top-level entries are listed once, then each entry is inspected by
    ``_discover_one_dataset`` on a thread pool so the per-dataset lookups
    overlap. Results are stored in the order the workers finish.

    Args:
        tf_module: Module object providing ``io.gfile``, or ``None``.
        bucket_url: Root URL of the bucket.

    Returns:
        ``{dataset_name: {version: path}}``; empty when ``tf_module`` is
        ``None`` or the bucket could not be listed.
    """
    if tf_module is None:
        return {}

    catalog: Dict[str, Dict[str, str]] = {}
    top_level = _safe_gfile_listdir(tf_module, bucket_url)

    with ThreadPoolExecutor(max_workers=_SCAN_WORKERS) as pool:
        jobs = [
            pool.submit(_discover_one_dataset, tf_module, bucket_url, entry)
            for entry in top_level
        ]
        for finished in as_completed(jobs):
            outcome = finished.result()
            if outcome is None:
                # Entry was not a directory.
                continue
            dataset_name, versions = outcome
            catalog[dataset_name] = versions

    return catalog
|