opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/datasets/v21/_remove_language_instruction.py
@@ -0,0 +1,109 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import traceback
+ from pathlib import Path
+
+ from datasets import get_dataset_config_info
+ from huggingface_hub import HfApi
+
+ from opentau import available_datasets
+ from opentau.datasets.lerobot_dataset import LeRobotDatasetMetadata
+ from opentau.datasets.utils import INFO_PATH, write_info
+ from opentau.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
+
+ LOCAL_DIR = Path("data/")
+
+ hub_api = HfApi()
+
+
+ def fix_dataset(repo_id: str) -> str:
+     """Remove 'language_instruction' feature from dataset metadata if present.
+
+     Checks if the dataset has a 'language_instruction' feature in metadata
+     that doesn't exist in parquet files, and removes it from info.json.
+
+     Args:
+         repo_id: Repository ID of the dataset to fix.
+
+     Returns:
+         Status message indicating success, skip reason, or error.
+
+     Raises:
+         ValueError: If there are unexpected feature differences between
+             parquet files and metadata, or if the difference is not
+             just 'language_instruction'.
+     """
+     if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
+         return f"{repo_id}: skipped (not in {V20})."
+
+     dataset_info = get_dataset_config_info(repo_id, "default")
+     with SuppressWarnings():
+         lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
+
+     meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
+     parquet_features = set(dataset_info.features)
+
+     diff_parquet_meta = parquet_features - meta_features
+     diff_meta_parquet = meta_features - parquet_features
+
+     if diff_parquet_meta:
+         raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
+
+     if not diff_meta_parquet:
+         return f"{repo_id}: skipped (no diff)"
+
+     if diff_meta_parquet:
+         logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
+         assert diff_meta_parquet == {"language_instruction"}
+         lerobot_metadata.features.pop("language_instruction")
+         write_info(lerobot_metadata.info, lerobot_metadata.root)
+         commit_info = hub_api.upload_file(
+             path_or_fileobj=lerobot_metadata.root / INFO_PATH,
+             path_in_repo=INFO_PATH,
+             repo_id=repo_id,
+             repo_type="dataset",
+             revision=V20,
+             commit_message="Remove 'language_instruction'",
+             create_pr=True,
+         )
+         return f"{repo_id}: success - PR: {commit_info.pr_url}"
+
+
+ def batch_fix() -> None:
+     """Batch process all available datasets to remove language_instruction feature.
+
+     Iterates through all datasets in available_datasets and attempts to fix
+     each one, logging results to a file.
+     """
+     status = {}
+     LOCAL_DIR.mkdir(parents=True, exist_ok=True)
+     logfile = LOCAL_DIR / "fix_features_v20.txt"
+     for num, repo_id in enumerate(available_datasets):
+         print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+         print("---------------------------------------------------------")
+         try:
+             status = fix_dataset(repo_id)
+         except Exception:
+             status = f"{repo_id}: failed\n {traceback.format_exc()}"
+
+         logging.info(status)
+         with open(logfile, "a") as file:
+             file.write(status + "\n")
+
+
+ if __name__ == "__main__":
+     batch_fix()
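
Note: besides the batch_fix() entry point above, fix_dataset can be called directly for a single repository. A minimal sketch, assuming the module path listed in the file table and a placeholder repo id (not a real dataset):

```python
# Minimal sketch: run the metadata fix for one dataset instead of the whole batch.
from opentau.datasets.v21._remove_language_instruction import fix_dataset

if __name__ == "__main__":
    # "your-org/your-dataset" is a hypothetical placeholder repo id.
    status = fix_dataset("your-org/your-dataset")
    print(status)  # e.g. "...: skipped (no diff)" or "...: success - PR: <url>"
```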
opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py
@@ -0,0 +1,60 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
+ """
+
+ import traceback
+ from pathlib import Path
+
+ from huggingface_hub import HfApi
+
+ from opentau import available_datasets
+ from opentau.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
+
+ LOCAL_DIR = Path("data/")
+
+
+ def batch_convert() -> None:
+     """Batch convert multiple datasets from v2.0 to v2.1 format.
+
+     Processes all datasets in available_datasets, converting each one
+     and logging the results to a file. Skips datasets already in v2.1.
+     """
+     status = {}
+     LOCAL_DIR.mkdir(parents=True, exist_ok=True)
+     logfile = LOCAL_DIR / "conversion_log_v21.txt"
+     hub_api = HfApi()
+     for num, repo_id in enumerate(available_datasets):
+         print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
+         print("---------------------------------------------------------")
+         try:
+             if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
+                 status = f"{repo_id}: success (already in {V21})."
+             else:
+                 convert_dataset(repo_id)
+                 status = f"{repo_id}: success."
+         except Exception:
+             status = f"{repo_id}: failed\n {traceback.format_exc()}"
+
+         with open(logfile, "a") as file:
+             file.write(status + "\n")
+
+
+ if __name__ == "__main__":
+     batch_convert()
opentau/datasets/v21/convert_dataset_v20_to_v21.py
@@ -0,0 +1,183 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
+ 2.1. It will:
+
+ - Generate per-episode stats and write them in `episodes_stats.jsonl`.
+ - Check consistency between these new stats and the old ones.
+ - Remove the deprecated `stats.json`.
+ - Update codebase_version in `info.json`.
+ - Push this new version to the hub on the 'main' branch and tag it with "v2.1".
+
+ Usage:
+
+ ```bash
+ python opentau/datasets/v21/convert_dataset_v20_to_v21.py \
+     --repo-id=aliberts/koch_tutorial
+ ```
+
+ """
+
+ import argparse
+ import logging
+ from dataclasses import dataclass
+
+ from huggingface_hub import HfApi
+
+ from opentau.configs.default import DatasetConfig, DatasetMixtureConfig
+ from opentau.configs.policies import PreTrainedConfig
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+ from opentau.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
+ from opentau.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
+
+ V20 = "v2.0"
+ V21 = "v2.1"
+
+
+ class SuppressWarnings:
+     """Context manager to temporarily suppress logging warnings.
+
+     Sets logging level to ERROR on entry and restores previous level on exit.
+     """
+
+     def __enter__(self):
+         self.previous_level = logging.getLogger().getEffectiveLevel()
+         logging.getLogger().setLevel(logging.ERROR)
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         logging.getLogger().setLevel(self.previous_level)
+
+
+ def create_fake_train_config() -> TrainPipelineConfig:
+     """Create a fake TrainPipelineConfig for dataset conversion."""
+
+     # Minimal dummy PreTrainedConfig implementation
+     @dataclass
+     class DummyPolicyConfig(PreTrainedConfig):
+         @property
+         def observation_delta_indices(self):
+             return None
+
+         @property
+         def action_delta_indices(self):
+             return None
+
+         @property
+         def reward_delta_indices(self):
+             return None
+
+         def get_optimizer_preset(self):
+             return None
+
+         def get_scheduler_preset(self):
+             return None
+
+         def validate_features(self):
+             pass
+
+     # Create minimal config components
+     dataset_cfg = DatasetConfig(repo_id="dummy")  # Will be overridden by LeRobotDataset
+     mixture_cfg = DatasetMixtureConfig(datasets=[dataset_cfg], weights=[1.0])
+     policy_cfg = DummyPolicyConfig()
+
+     # Create the main config with minimal required parameters
+     cfg = TrainPipelineConfig(
+         dataset_mixture=mixture_cfg,
+         policy=policy_cfg,
+         resolution=(224, 224),
+         num_cams=2,
+         max_state_dim=32,
+         max_action_dim=32,
+         action_chunk=50,
+     )
+
+     return cfg
+
+
+ def convert_dataset(
+     repo_id: str,
+     branch: str | None = None,
+     num_workers: int = 4,
+ ) -> None:
+     """Convert a dataset from v2.0 to v2.1 format.
+
+     Converts statistics from global format to per-episode format, updates
+     codebase version, and pushes changes to the hub.
+
+     Args:
+         repo_id: Repository ID of the dataset to convert.
+         branch: Git branch to push changes to. If None, uses default branch.
+         num_workers: Number of worker threads for parallel statistics computation.
+             Defaults to 4.
+     """
+     with SuppressWarnings():
+         # Create fake config for the dataset
+         cfg = create_fake_train_config()
+         dataset = LeRobotDataset(cfg, repo_id, revision=V20, force_cache_sync=True)
+
+     if (dataset.root / EPISODES_STATS_PATH).is_file():
+         (dataset.root / EPISODES_STATS_PATH).unlink()
+
+     convert_stats(dataset, num_workers=num_workers)
+     ref_stats = load_stats(dataset.root)
+     check_aggregate_stats(dataset, ref_stats)
+
+     dataset.meta.info["codebase_version"] = CODEBASE_VERSION
+     write_info(dataset.meta.info, dataset.root)
+
+     dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
+
+     # delete old stats.json file
+     if (dataset.root / STATS_PATH).is_file():
+         (dataset.root / STATS_PATH).unlink()
+
+     hub_api = HfApi()
+     if hub_api.file_exists(
+         repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
+     ):
+         hub_api.delete_file(
+             path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
+         )
+
+     hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--repo-id",
+         type=str,
+         required=True,
+         help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
+         "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+     )
+     parser.add_argument(
+         "--branch",
+         type=str,
+         default=None,
+         help="Repo branch to push your dataset. Defaults to the main branch.",
+     )
+     parser.add_argument(
+         "--num-workers",
+         type=int,
+         default=4,
+         help="Number of workers for parallelizing stats compute. Defaults to 4.",
+     )
+
+     args = parser.parse_args()
+     convert_dataset(**vars(args))
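
Note: besides the CLI shown in the module docstring, convert_dataset can be called from Python, which is exactly what the batch conversion script earlier in this diff does per repository. A minimal sketch with a placeholder repo id:

```python
# Minimal programmatic sketch of the same conversion the CLI performs.
from opentau.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset

# "your-org/your-dataset" is a hypothetical placeholder repo id; branch=None
# pushes to the default branch, num_workers controls stats parallelism.
convert_dataset("your-org/your-dataset", branch=None, num_workers=4)
```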
opentau/datasets/v21/convert_stats.py
@@ -0,0 +1,150 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ import numpy as np
+ from tqdm import tqdm
+
+ from opentau.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
+ from opentau.datasets.lerobot_dataset import LeRobotDataset
+ from opentau.datasets.utils import write_episode_stats
+
+
+ def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
+     """Sample video frames from an episode for statistics computation.
+
+     Uses evenly spaced sampling to reduce the number of frames processed
+     while maintaining representative statistics.
+
+     Args:
+         dataset: LeRobotDataset containing the episode.
+         episode_index: Index of the episode to sample from.
+         ft_key: Feature key for the video to sample.
+
+     Returns:
+         Numpy array of sampled video frames.
+     """
+     ep_len = dataset.meta.episodes[episode_index]["length"]
+     sampled_indices = sample_indices(ep_len)
+     query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
+     video_frames = dataset._query_videos(query_timestamps, episode_index)
+     return video_frames[ft_key].numpy()
+
+
+ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int) -> None:
+     """Convert statistics for a single episode from v2.0 to v2.1 format.
+
+     Computes per-episode statistics, sampling video frames if needed.
+     Stores results in dataset.meta.episodes_stats.
+
+     Args:
+         dataset: LeRobotDataset containing the episode.
+         ep_idx: Index of the episode to convert.
+     """
+     ep_start_idx = dataset.episode_data_index["from"][ep_idx]
+     ep_end_idx = dataset.episode_data_index["to"][ep_idx]
+     ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
+
+     ep_stats = {}
+     for key, ft in dataset.features.items():
+         if ft["dtype"] == "video":
+             # We sample only for videos
+             ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
+         else:
+             ep_ft_data = np.array(ep_data[key])
+
+         axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
+         keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
+         ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
+
+         if ft["dtype"] in ["image", "video"]:  # remove batch dim
+             ep_stats[key] = {
+                 k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+             }
+
+     dataset.meta.episodes_stats[ep_idx] = ep_stats
+
+
+ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0) -> None:
+     """Convert dataset statistics from v2.0 to v2.1 format.
+
+     Computes per-episode statistics for all episodes and writes them to disk.
+     Can use parallel processing with ThreadPoolExecutor for faster computation.
+
+     Args:
+         dataset: LeRobotDataset to convert (must have episodes=None to process all).
+         num_workers: Number of worker threads for parallel processing.
+             If 0, processes sequentially. Defaults to 0.
+
+     Raises:
+         AssertionError: If dataset.episodes is not None.
+     """
+     assert dataset.episodes is None
+     print("Computing episodes stats")
+     total_episodes = dataset.meta.total_episodes
+     if num_workers > 0:
+         with ThreadPoolExecutor(max_workers=num_workers) as executor:
+             futures = {
+                 executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
+                 for ep_idx in range(total_episodes)
+             }
+             for future in tqdm(as_completed(futures), total=total_episodes):
+                 future.result()
+     else:
+         for ep_idx in tqdm(range(total_episodes)):
+             convert_episode_stats(dataset, ep_idx)
+
+     for ep_idx in tqdm(range(total_episodes)):
+         write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
+
+
+ def check_aggregate_stats(
+     dataset: LeRobotDataset,
+     reference_stats: dict[str, dict[str, np.ndarray]],
+     video_rtol_atol: tuple[float] = (1e-2, 1e-2),
+     default_rtol_atol: tuple[float] = (5e-6, 6e-5),
+ ) -> None:
+     """Verify that aggregated episode statistics match reference statistics.
+
+     Aggregates per-episode statistics and compares them to reference stats
+     with appropriate tolerances for different feature types.
+
+     Args:
+         dataset: LeRobotDataset with episodes_stats populated.
+         reference_stats: Reference statistics dictionary to compare against.
+         video_rtol_atol: Relative and absolute tolerance for video features.
+             Defaults to (1e-2, 1e-2) to account for image sub-sampling.
+         default_rtol_atol: Relative and absolute tolerance for other features.
+             Defaults to (5e-6, 6e-5).
+
+     Raises:
+         AssertionError: If aggregated stats don't match reference within tolerance.
+     """
+     agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
+     for key, ft in dataset.features.items():
+         # These values might need some fine-tuning
+         if ft["dtype"] == "video":
+             # to account for image sub-sampling
+             rtol, atol = video_rtol_atol
+         else:
+             rtol, atol = default_rtol_atol
+
+         for stat, val in agg_stats[key].items():
+             if key in reference_stats and stat in reference_stats[key]:
+                 err_msg = f"feature='{key}' stats='{stat}'"
+                 np.testing.assert_allclose(
+                     val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
+                 )
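
Note: to make the reduction axes in convert_episode_stats concrete, the following small numpy sketch (hypothetical shapes, not part of the package) shows what reducing over axes (0, 2, 3) with keepdims=True does for an image/video feature, and why the leading batch dim is squeezed afterwards:

```python
import numpy as np

# Hypothetical stack of sampled frames: 10 frames, 3 channels, 96x96 pixels.
frames = np.random.rand(10, 3, 96, 96).astype(np.float32)

# Reducing over frames and spatial dims leaves per-channel statistics.
per_channel_mean = frames.mean(axis=(0, 2, 3), keepdims=True)
print(per_channel_mean.shape)  # (1, 3, 1, 1)

# convert_episode_stats then removes the leading batch dim for image/video keys.
print(np.squeeze(per_channel_mean, axis=0).shape)  # (3, 1, 1)
```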