opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/datasets/v21/_remove_language_instruction.py
@@ -0,0 +1,109 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import traceback
+from pathlib import Path
+
+from datasets import get_dataset_config_info
+from huggingface_hub import HfApi
+
+from opentau import available_datasets
+from opentau.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from opentau.datasets.utils import INFO_PATH, write_info
+from opentau.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
+
+LOCAL_DIR = Path("data/")
+
+hub_api = HfApi()
+
+
+def fix_dataset(repo_id: str) -> str:
+    """Remove 'language_instruction' feature from dataset metadata if present.
+
+    Checks if the dataset has a 'language_instruction' feature in metadata
+    that doesn't exist in parquet files, and removes it from info.json.
+
+    Args:
+        repo_id: Repository ID of the dataset to fix.
+
+    Returns:
+        Status message indicating success, skip reason, or error.
+
+    Raises:
+        ValueError: If there are unexpected feature differences between
+            parquet files and metadata, or if the difference is not
+            just 'language_instruction'.
+    """
+    if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
+        return f"{repo_id}: skipped (not in {V20})."
+
+    dataset_info = get_dataset_config_info(repo_id, "default")
+    with SuppressWarnings():
+        lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
+
+    meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
+    parquet_features = set(dataset_info.features)
+
+    diff_parquet_meta = parquet_features - meta_features
+    diff_meta_parquet = meta_features - parquet_features
+
+    if diff_parquet_meta:
+        raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
+
+    if not diff_meta_parquet:
+        return f"{repo_id}: skipped (no diff)"
+
+    if diff_meta_parquet:
+        logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
+        assert diff_meta_parquet == {"language_instruction"}
+        lerobot_metadata.features.pop("language_instruction")
+        write_info(lerobot_metadata.info, lerobot_metadata.root)
+        commit_info = hub_api.upload_file(
+            path_or_fileobj=lerobot_metadata.root / INFO_PATH,
+            path_in_repo=INFO_PATH,
+            repo_id=repo_id,
+            repo_type="dataset",
+            revision=V20,
+            commit_message="Remove 'language_instruction'",
+            create_pr=True,
+        )
+        return f"{repo_id}: success - PR: {commit_info.pr_url}"
+
+
+def batch_fix() -> None:
+    """Batch process all available datasets to remove language_instruction feature.
+
+    Iterates through all datasets in available_datasets and attempts to fix
+    each one, logging results to a file.
+    """
+    status = {}
+    LOCAL_DIR.mkdir(parents=True, exist_ok=True)
+    logfile = LOCAL_DIR / "fix_features_v20.txt"
+    for num, repo_id in enumerate(available_datasets):
+        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+        print("---------------------------------------------------------")
+        try:
+            status = fix_dataset(repo_id)
+        except Exception:
+            status = f"{repo_id}: failed\n {traceback.format_exc()}"
+
+        logging.info(status)
+        with open(logfile, "a") as file:
+            file.write(status + "\n")
+
+
+if __name__ == "__main__":
+    batch_fix()
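
For reference, these helpers can also be driven one dataset at a time rather than through `batch_fix`. The sketch below is illustrative only: the repo id is a placeholder, and it assumes the dataset has a v2.0 revision on the Hub.

```python
# Minimal sketch: fix a single dataset instead of looping over available_datasets.
# "your-org/your-dataset" is a placeholder repo id, not part of the package.
from opentau.datasets.v21._remove_language_instruction import fix_dataset

status = fix_dataset("your-org/your-dataset")
print(status)  # e.g. a success message with a PR URL, or a skip reason
```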

opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
+"""
+
+import traceback
+from pathlib import Path
+
+from huggingface_hub import HfApi
+
+from opentau import available_datasets
+from opentau.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
+
+LOCAL_DIR = Path("data/")
+
+
+def batch_convert() -> None:
+    """Batch convert multiple datasets from v2.0 to v2.1 format.
+
+    Processes all datasets in available_datasets, converting each one
+    and logging the results to a file. Skips datasets already in v2.1.
+    """
+    status = {}
+    LOCAL_DIR.mkdir(parents=True, exist_ok=True)
+    logfile = LOCAL_DIR / "conversion_log_v21.txt"
+    hub_api = HfApi()
+    for num, repo_id in enumerate(available_datasets):
+        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+        print("---------------------------------------------------------")
+        try:
+            if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
+                status = f"{repo_id}: success (already in {V21})."
+            else:
+                convert_dataset(repo_id)
+                status = f"{repo_id}: success."
+        except Exception:
+            status = f"{repo_id}: failed\n {traceback.format_exc()}"
+
+        with open(logfile, "a") as file:
+            file.write(status + "\n")
+
+
+if __name__ == "__main__":
+    batch_convert()
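
Because the conversion loop is exposed as a plain function behind the `if __name__ == "__main__"` guard, it can presumably also be invoked programmatically. A minimal sketch, assuming the package is installed and Hub credentials are configured:

```python
# Minimal sketch: run the batch v2.0 -> v2.1 conversion from Python.
from opentau.datasets.v21.batch_convert_dataset_v20_to_v21 import batch_convert

# Writes one status line per repo to data/conversion_log_v21.txt.
batch_convert()
```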

opentau/datasets/v21/convert_dataset_v20_to_v21.py
@@ -0,0 +1,183 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
+2.1. It will:
+
+- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
+- Check consistency between these new stats and the old ones.
+- Remove the deprecated `stats.json`.
+- Update codebase_version in `info.json`.
+- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
+
+Usage:
+
+```bash
+python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
+    --repo-id=aliberts/koch_tutorial
+```
+
+"""
+
+import argparse
+import logging
+from dataclasses import dataclass
+
+from huggingface_hub import HfApi
+
+from opentau.configs.default import DatasetConfig, DatasetMixtureConfig
+from opentau.configs.policies import PreTrainedConfig
+from opentau.configs.train import TrainPipelineConfig
+from opentau.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from opentau.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
+from opentau.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
+
+V20 = "v2.0"
+V21 = "v2.1"
+
+
+class SuppressWarnings:
+    """Context manager to temporarily suppress logging warnings.
+
+    Sets logging level to ERROR on entry and restores previous level on exit.
+    """
+
+    def __enter__(self):
+        self.previous_level = logging.getLogger().getEffectiveLevel()
+        logging.getLogger().setLevel(logging.ERROR)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        logging.getLogger().setLevel(self.previous_level)
+
+
+def create_fake_train_config() -> TrainPipelineConfig:
+    """Create a fake TrainPipelineConfig for dataset conversion."""
+
+    # Minimal dummy PreTrainedConfig implementation
+    @dataclass
+    class DummyPolicyConfig(PreTrainedConfig):
+        @property
+        def observation_delta_indices(self):
+            return None
+
+        @property
+        def action_delta_indices(self):
+            return None
+
+        @property
+        def reward_delta_indices(self):
+            return None
+
+        def get_optimizer_preset(self):
+            return None
+
+        def get_scheduler_preset(self):
+            return None
+
+        def validate_features(self):
+            pass
+
+    # Create minimal config components
+    dataset_cfg = DatasetConfig(repo_id="dummy")  # Will be overridden by LeRobotDataset
+    mixture_cfg = DatasetMixtureConfig(datasets=[dataset_cfg], weights=[1.0])
+    policy_cfg = DummyPolicyConfig()
+
+    # Create the main config with minimal required parameters
+    cfg = TrainPipelineConfig(
+        dataset_mixture=mixture_cfg,
+        policy=policy_cfg,
+        resolution=(224, 224),
+        num_cams=2,
+        max_state_dim=32,
+        max_action_dim=32,
+        action_chunk=50,
+    )
+
+    return cfg
+
+
+def convert_dataset(
+    repo_id: str,
+    branch: str | None = None,
+    num_workers: int = 4,
+) -> None:
+    """Convert a dataset from v2.0 to v2.1 format.
+
+    Converts statistics from global format to per-episode format, updates
+    codebase version, and pushes changes to the hub.
+
+    Args:
+        repo_id: Repository ID of the dataset to convert.
+        branch: Git branch to push changes to. If None, uses default branch.
+        num_workers: Number of worker threads for parallel statistics computation.
+            Defaults to 4.
+    """
+    with SuppressWarnings():
+        # Create fake config for the dataset
+        cfg = create_fake_train_config()
+        dataset = LeRobotDataset(cfg, repo_id, revision=V20, force_cache_sync=True)
+
+    if (dataset.root / EPISODES_STATS_PATH).is_file():
+        (dataset.root / EPISODES_STATS_PATH).unlink()
+
+    convert_stats(dataset, num_workers=num_workers)
+    ref_stats = load_stats(dataset.root)
+    check_aggregate_stats(dataset, ref_stats)
+
+    dataset.meta.info["codebase_version"] = CODEBASE_VERSION
+    write_info(dataset.meta.info, dataset.root)
+
+    dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
+
+    # delete old stats.json file
+    if (dataset.root / STATS_PATH).is_file():
+        (dataset.root / STATS_PATH).unlink()
+
+    hub_api = HfApi()
+    if hub_api.file_exists(
+        repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
+    ):
+        hub_api.delete_file(
+            path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
+        )
+
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
+        "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+    )
+    parser.add_argument(
+        "--branch",
+        type=str,
+        default=None,
+        help="Repo branch to push your dataset. Defaults to the main branch.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for parallelizing stats compute. Defaults to 4.",
+    )
+
+    args = parser.parse_args()
+    convert_dataset(**vars(args))
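
Note that the usage snippet in the module docstring still points at the upstream `lerobot` path, whereas this wheel ships the module as `opentau/datasets/v21/convert_dataset_v20_to_v21.py`. A single dataset could presumably also be converted programmatically; the repo id below is a placeholder:

```python
# Minimal sketch: convert one dataset from v2.0 to v2.1 via the public function.
# The repo id is a placeholder; branch=None pushes to the default branch.
from opentau.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset

convert_dataset("your-org/your-dataset", branch=None, num_workers=4)
```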

opentau/datasets/v21/convert_stats.py
@@ -0,0 +1,150 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import numpy as np
+from tqdm import tqdm
+
+from opentau.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
+from opentau.datasets.lerobot_dataset import LeRobotDataset
+from opentau.datasets.utils import write_episode_stats
+
+
+def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
+    """Sample video frames from an episode for statistics computation.
+
+    Uses evenly spaced sampling to reduce the number of frames processed
+    while maintaining representative statistics.
+
+    Args:
+        dataset: LeRobotDataset containing the episode.
+        episode_index: Index of the episode to sample from.
+        ft_key: Feature key for the video to sample.
+
+    Returns:
+        Numpy array of sampled video frames.
+    """
+    ep_len = dataset.meta.episodes[episode_index]["length"]
+    sampled_indices = sample_indices(ep_len)
+    query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
+    video_frames = dataset._query_videos(query_timestamps, episode_index)
+    return video_frames[ft_key].numpy()
+
+
+def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int) -> None:
+    """Convert statistics for a single episode from v2.0 to v2.1 format.
+
+    Computes per-episode statistics, sampling video frames if needed.
+    Stores results in dataset.meta.episodes_stats.
+
+    Args:
+        dataset: LeRobotDataset containing the episode.
+        ep_idx: Index of the episode to convert.
+    """
+    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
+    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+    ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
+
+    ep_stats = {}
+    for key, ft in dataset.features.items():
+        if ft["dtype"] == "video":
+            # We sample only for videos
+            ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
+        else:
+            ep_ft_data = np.array(ep_data[key])
+
+        axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
+        keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
+        ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
+
+        if ft["dtype"] in ["image", "video"]:  # remove batch dim
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+            }
+
+    dataset.meta.episodes_stats[ep_idx] = ep_stats
+
+
+def convert_stats(dataset: LeRobotDataset, num_workers: int = 0) -> None:
+    """Convert dataset statistics from v2.0 to v2.1 format.
+
+    Computes per-episode statistics for all episodes and writes them to disk.
+    Can use parallel processing with ThreadPoolExecutor for faster computation.
+
+    Args:
+        dataset: LeRobotDataset to convert (must have episodes=None to process all).
+        num_workers: Number of worker threads for parallel processing.
+            If 0, processes sequentially. Defaults to 0.
+
+    Raises:
+        AssertionError: If dataset.episodes is not None.
+    """
+    assert dataset.episodes is None
+    print("Computing episodes stats")
+    total_episodes = dataset.meta.total_episodes
+    if num_workers > 0:
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = {
+                executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
+                for ep_idx in range(total_episodes)
+            }
+            for future in tqdm(as_completed(futures), total=total_episodes):
+                future.result()
+    else:
+        for ep_idx in tqdm(range(total_episodes)):
+            convert_episode_stats(dataset, ep_idx)
+
+    for ep_idx in tqdm(range(total_episodes)):
+        write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
+
+
+def check_aggregate_stats(
+    dataset: LeRobotDataset,
+    reference_stats: dict[str, dict[str, np.ndarray]],
+    video_rtol_atol: tuple[float] = (1e-2, 1e-2),
+    default_rtol_atol: tuple[float] = (5e-6, 6e-5),
+) -> None:
+    """Verify that aggregated episode statistics match reference statistics.
+
+    Aggregates per-episode statistics and compares them to reference stats
+    with appropriate tolerances for different feature types.
+
+    Args:
+        dataset: LeRobotDataset with episodes_stats populated.
+        reference_stats: Reference statistics dictionary to compare against.
+        video_rtol_atol: Relative and absolute tolerance for video features.
+            Defaults to (1e-2, 1e-2) to account for image sub-sampling.
+        default_rtol_atol: Relative and absolute tolerance for other features.
+            Defaults to (5e-6, 6e-5).
+
+    Raises:
+        AssertionError: If aggregated stats don't match reference within tolerance.
+    """
+    agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
+    for key, ft in dataset.features.items():
+        # These values might need some fine-tuning
+        if ft["dtype"] == "video":
+            # to account for image sub-sampling
+            rtol, atol = video_rtol_atol
+        else:
+            rtol, atol = default_rtol_atol
+
+        for stat, val in agg_stats[key].items():
+            if key in reference_stats and stat in reference_stats[key]:
+                err_msg = f"feature='{key}' stats='{stat}'"
+                np.testing.assert_allclose(
+                    val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
+                )
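
The axis and `keepdims` choices in `convert_episode_stats` encode the expected array layouts: image/video features are reduced over the batch and spatial axes so one statistic per channel remains, while 1D features such as states and actions are reduced over the time axis only. Below is a standalone numpy sketch of those reductions, assuming channel-first frames and using illustrative shapes (the actual statistics come from `get_feature_stats`):

```python
import numpy as np

# Illustrative shapes: 10 sampled frames of a (3, 224, 224) camera, 100 steps of a 7-dim state.
video = np.random.rand(10, 3, 224, 224)  # (N, C, H, W), channel-first assumed
state = np.random.rand(100, 7)           # (T, D)

# Image/video: reduce over batch + spatial axes with keepdims, then drop the batch dim,
# leaving per-channel stats of shape (C, 1, 1).
video_mean = video.mean(axis=(0, 2, 3), keepdims=True)
video_mean = np.squeeze(video_mean, axis=0)  # (3, 1, 1)

# 1D features: reduce over the time axis only, giving one value per dimension.
state_min = state.min(axis=0)  # (7,)
state_max = state.max(axis=0)  # (7,)
```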