opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/datasets/dataset_mixture.py
@@ -0,0 +1,460 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Weighted dataset mixture for combining multiple datasets with controlled sampling.

This module provides functionality to combine multiple PyTorch datasets into a
single weighted mixture, enabling training on heterogeneous datasets with
controlled sampling proportions. It supports hierarchical sampling strategies
that efficiently handle large-scale dataset combinations while maintaining
memory efficiency.

The module implements a two-level sampling approach:
1. Dataset-level sampling: Selects which dataset to sample from based on
   specified weights.
2. Sample-level sampling: Uniformly samples within the selected dataset.

This hierarchical approach avoids expensive multinomial sampling over millions
of individual samples by operating at the dataset level, making it scalable
for large dataset mixtures.

Key Features:
    - Weighted sampling: Control relative sampling frequency of different
      datasets through configurable weights.
    - Memory-efficient sampling: Hierarchical sampler processes samples in
      chunks to minimize memory overhead.
    - Metadata aggregation: Automatically aggregates and standardizes metadata
      from multiple datasets, including statistics normalization and feature
      name mapping.
    - Format standardization: Converts dataset-specific feature formats to a
      common standard format, handling vector padding and missing cameras.

Classes:
    WeightedDatasetMixture: Main class for combining multiple datasets with
        weighted sampling. Creates concatenated datasets and provides DataLoader
        with hierarchical sampling.
    HierarchicalSampler: Custom PyTorch sampler that implements two-level
        weighted sampling (dataset selection, then uniform sample selection).
    DatasetMixtureMetadata: Aggregates metadata from multiple datasets,
        standardizes feature names, pads vectors, and combines statistics.

Functions:
    pad_vector: Pads the last dimension of a vector to a target size with zeros.

Example:
    Create a dataset mixture with two datasets:
    >>> datasets = [dataset1, dataset2]
    >>> weights = [0.7, 0.3]  # 70% from dataset1, 30% from dataset2
    >>> mixture = WeightedDatasetMixture(cfg, datasets, weights, action_freq=30.0)
    >>> dataloader = mixture.get_dataloader()
"""

import logging
from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler

from opentau.configs.train import TrainPipelineConfig
from opentau.datasets.compute_stats import aggregate_stats
from opentau.datasets.lerobot_dataset import BaseDataset, DatasetMetadata
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING


def pad_vector(vector: np.ndarray, new_dim: int) -> np.ndarray:
    """Pad the last dimension of a vector to a target size with zeros.

    Args:
        vector: Input numpy array to pad.
        new_dim: Target size for the last dimension.

    Returns:
        Padded array with the last dimension expanded to new_dim. If the
        vector already has the target dimension, returns it unchanged.
    """
    if vector.shape[-1] == new_dim:
        return vector
    shape = list(vector.shape)
    current_dim = shape[-1]
    shape[-1] = new_dim
    new_vector = np.zeros(shape, dtype=vector.dtype)
    new_vector[..., :current_dim] = vector
    return new_vector
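
# A minimal usage sketch of `pad_vector` with made-up values (doctest-style,
# for illustration only; not part of the original source):
#
#     >>> pad_vector(np.array([1.0, 2.0, 3.0]), 5)
#     array([1., 2., 3., 0., 0.])
#     >>> pad_vector(np.ones((2, 3)), 4).shape  # only the last dimension is padded
#     (2, 4)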


class DatasetMixtureMetadata:
    """A class to hold metadata for a mixture of datasets.

    This is used to aggregate metadata from multiple datasets into a single object.
    """

    def __init__(
        self, cfg: TrainPipelineConfig, metadatas: List[DatasetMetadata], dataset_weights: List[float]
    ):
        self.cfg = cfg

        # convert each metadata stats to the standard data format
        for metadata in metadatas:
            metadata.stats = self._to_standard_data_format(metadata.repo_id, metadata.stats)

        self.stats = aggregate_stats([metadata.stats for metadata in metadatas], weights=dataset_weights)

    def _to_standard_data_format(
        self, repo_id: str, stats: dict[str, dict[str, np.ndarray]]
    ) -> dict[str, dict[str, np.ndarray]]:
        """Convert statistics to the standard data format.

        Maps feature names from dataset-specific format to standard format,
        pads state and action vectors, and ensures all required cameras are present.

        Args:
            repo_id: Repository ID used to look up feature name mapping.
            stats: Statistics dictionary with dataset-specific feature names.

        Returns:
            Statistics dictionary with standard feature names and padded vectors.

        Raises:
            KeyError: If a required feature is missing from stats or if required
                statistics (mean, std, min, max) are missing.
        """
        name_map = DATA_FEATURES_NAME_MAPPING[repo_id]
        features_without_stats = ["prompt", "response", "advantage"]

        standard_stats = {}
        for new_key, key in name_map.items():
            if new_key in features_without_stats:
                # skip features that do not have stats
                continue

            # ensure only the first num_cams is used
            if new_key.startswith("camera"):
                cam_idx = int(new_key[len("camera") :])
                if cam_idx >= self.cfg.num_cams:
                    continue
            if key in stats:
                standard_stats[new_key] = stats[key]
            else:
                raise KeyError(f"Key '{key}' not found in stats. Available keys: {list(stats.keys())}")

        # pad state and action vectors
        for stat in standard_stats["state"]:
            if stat in ["mean", "std", "min", "max"]:
                standard_stats["state"][stat] = pad_vector(
                    standard_stats["state"][stat], self.cfg.max_state_dim
                )
                standard_stats["actions"][stat] = pad_vector(
                    standard_stats["actions"][stat], self.cfg.max_action_dim
                )

        # pad missing cameras
        for cam_idx in range(self.cfg.num_cams):
            if f"camera{cam_idx}" in standard_stats:
                continue
            standard_stats[f"camera{cam_idx}"] = {
                "min": np.zeros((3, 1, 1), dtype=np.float32),
                "max": np.ones((3, 1, 1), dtype=np.float32),
                "mean": np.zeros((3, 1, 1), dtype=np.float32),
                "std": np.zeros((3, 1, 1), dtype=np.float32),
                "count": np.array(
                    standard_stats["state"]["count"]
                ),  # create a copy in case this gets modified
            }

        # check for missing keys
        for data in standard_stats:
            missing_keys = {"mean", "std", "min", "max"} - standard_stats[data].keys()
            if missing_keys:
                raise KeyError(
                    f"The dataset {repo_id} is missing required statistics: {', '.join(sorted(missing_keys))}"
                )

        return standard_stats

    @property
    def features(self) -> dict[str, dict]:
        """Return standard data format"""
        features = {
            "state": {
                "shape": (self.cfg.max_state_dim,),
                "dtype": "float32",
            },
            "actions": {
                "shape": (self.cfg.max_action_dim,),
                "dtype": "float32",
            },
        }
        # add camera features
        for i in range(self.cfg.num_cams):
            features[f"camera{i}"] = {
                "shape": (3, self.cfg.resolution[0], self.cfg.resolution[1]),
                "dtype": "image",
            }
        return features


class HierarchicalSampler(Sampler[int]):
    r"""With-replacement sampler for a ConcatDataset that first samples a dataset according to `dataset_probs`, and then
    samples uniformly within that dataset. This avoids multinomial over a huge number of categories (over 2^24)
    by operating at the dataset level.
    """

    def __init__(
        self,
        dataset_lengths: List[int],
        dataset_probs: List[float],
        num_samples: int,
        *,
        generator: Optional[torch.Generator] = None,
        seed: Optional[int] = None,
        chunk_size: int = 262144,
    ):
        super().__init__()

        if len(dataset_lengths) != len(dataset_probs):
            raise ValueError("dataset_lengths and dataset_probs must have the same length.")
        self.num_samples = int(num_samples)
        self.chunk_size = int(chunk_size)

        lens = torch.as_tensor(dataset_lengths, dtype=torch.long)
        probs = torch.as_tensor(dataset_probs, dtype=torch.double)

        if (lens < 0).any():
            raise ValueError("dataset_lengths must be non-negative.")

        # Offsets for mapping local indices to global ConcatDataset indices
        self._full_offsets = torch.zeros(len(lens), dtype=torch.long)
        if len(lens) > 0:
            self._full_offsets[1:] = lens.cumsum(0)[:-1]

        # Keep only non-empty datasets with positive probability
        valid_mask = (lens > 0) & (probs > 0)
        if not bool(valid_mask.any()):
            raise ValueError("All datasets are empty or have zero probability.")

        self._valid_ids = torch.nonzero(valid_mask, as_tuple=False).flatten()
        self._valid_lens = lens[self._valid_ids]
        valid_probs = probs[self._valid_ids]
        self._valid_probs = (valid_probs / valid_probs.sum()).to(dtype=torch.double)

        self._num_valid = int(self._valid_ids.numel())
        self._gen = generator if generator is not None else torch.Generator()
        if seed is not None:
            self._gen.manual_seed(int(seed))

    def __len__(self) -> int:
        return self.num_samples

    def __iter__(self):
        # Generate indices in memory-friendly chunks
        total = self.num_samples
        cs = self.chunk_size
        for start in range(0, total, cs):
            m = min(cs, total - start)

            # Choose dataset ids according to probs (over valid ids only)
            ds_choices_valid = torch.multinomial(self._valid_probs, m, replacement=True, generator=self._gen)

            # For each chosen dataset, draw uniform local indices and map to global indices
            out = torch.empty(m, dtype=torch.long)
            for k in range(self._num_valid):
                mask = ds_choices_valid == k
                k_count = int(mask.sum().item())
                if k_count == 0:
                    continue
                local_idx = torch.randint(0, int(self._valid_lens[k].item()), (k_count,), generator=self._gen)
                orig_ds_id = int(self._valid_ids[k].item())
                out[mask] = local_idx + self._full_offsets[orig_ds_id]

            # Yield one by one to conform to Sampler API
            for idx in out.tolist():
                yield int(idx)
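
# A minimal usage sketch for HierarchicalSampler (illustrative only; the toy
# lengths, probabilities, and seed below are made up). With dataset lengths
# [100, 25], global indices in [0, 100) map to the first dataset and indices
# in [100, 125) to the second:
#
#     sampler = HierarchicalSampler(
#         dataset_lengths=[100, 25],
#         dataset_probs=[0.8, 0.2],
#         num_samples=10,
#         seed=0,
#     )
#     indices = list(sampler)  # 10 global ConcatDataset indices, ~80%/20% split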


class WeightedDatasetMixture:
    """
    A class to combine multiple PyTorch Datasets and create a DataLoader
    that samples from them according to specified weightings.
    """

    def __init__(
        self,
        cfg: TrainPipelineConfig,
        datasets: List[BaseDataset],
        dataset_weights: List[float],
        action_freq: float,
    ):
        """
        Initializes the WeightedDatasetMixture.

        Args:
            cfg (TrainPipelineConfig): Configuration for the training pipeline.
            datasets (List[Dataset]): A list of PyTorch Dataset objects.
            dataset_weights (List[float]): A list of weights corresponding to each dataset.
                These determine the relative sampling frequency.
            action_freq (float): Frequency used for resampling the action output.
        """
        if not datasets:
            raise ValueError("The list of datasets cannot be empty.")
        if len(datasets) != len(dataset_weights):
            raise ValueError("The number of datasets must match the number of dataset_weights.")
        if any(w < 0 for w in dataset_weights):
            raise ValueError("Dataset weights must be non-negative.")
        if sum(dataset_weights) == 0 and any(len(ds) > 0 for ds in datasets):
            # If all weights are zero, but there's data, sampler will fail.
            # If all datasets are empty, sum of weights being zero is fine.
            logging.warning(
                "Warning: All dataset weights are zero. The sampler might not behave as expected if datasets have samples."
            )

        self.cfg = cfg
        self.datasets = datasets
        self.dataset_weights = dataset_weights
        self.action_freq = action_freq  # Frequency used for resampling action output
        self.dataset_names = [type(ds).__name__ + f"_{i}" for i, ds in enumerate(datasets)]  # For logging

        logging.info("Initializing WeightedDatasetMixture...")
        self._log_dataset_info()

        self.concatenated_dataset: ConcatDataset = ConcatDataset(datasets)
        logging.info(f"Total length of concatenated dataset: {len(self.concatenated_dataset)}")

        self.sample_weights: torch.Tensor = self._calculate_sample_weights()
        if self.sample_weights is None and len(self.concatenated_dataset) > 0:
            raise ValueError("Sample weights could not be calculated, but concatenated dataset is not empty.")
        elif self.sample_weights is not None and len(self.sample_weights) != len(self.concatenated_dataset):
            raise ValueError(
                f"Length of sample_weights ({len(self.sample_weights)}) "
                f"must match concatenated_dataset length ({len(self.concatenated_dataset)})."
            )
        logging.info("-" * 30)

        # aggregate metadata
        if not all(hasattr(ds, "meta") and ds.meta is not None for ds in datasets):
            raise ValueError("All datasets must have a 'meta' attribute with valid metadata.")
        self.meta = DatasetMixtureMetadata(cfg, [ds.meta for ds in datasets], dataset_weights)

    def _log_dataset_info(self) -> None:
        """Log information about all datasets in the mixture."""
        logging.info("Dataset information:")
        for i, ds in enumerate(self.datasets):
            logging.info(f" - {self.dataset_names[i]}: Length={len(ds)}, Weight={self.dataset_weights[i]}")
        logging.info("-" * 30)

    def _calculate_sample_weights(self) -> Optional[torch.Tensor]:
        """Calculate the weight for each individual sample in the concatenated dataset.

        Samples from datasets with higher weights or smaller sizes (for a given weight)
        will have higher individual sample weights. Weight per sample = dataset_weight / dataset_length.

        Returns:
            Tensor of sample weights, or None if all datasets are empty or have zero weight.

        Raises:
            RuntimeError: If there's a mismatch between concatenated dataset length
                and calculated sample weights.
        """
        if not self.concatenated_dataset:  # Handles case where all input datasets are empty
            logging.warning("Warning: Concatenated dataset is empty. No sample weights to calculate.")
            return None

        logging.info("Calculating per-sample weights...")
        all_sample_weights: List[float] = []
        dataset_lengths = [len(ds) for ds in self.datasets]

        for i, length in enumerate(dataset_lengths):
            dataset_name = self.dataset_names[i]
            current_dataset_weight = self.dataset_weights[i]

            if length == 0:
                logging.info(f" Skipping {dataset_name} (length 0).")
                continue  # Skip empty datasets

            if current_dataset_weight == 0:
                # Assign zero weight to all samples in this dataset
                weight_per_sample = 0.0
                logging.info(
                    f" Weight for each sample in {dataset_name} (size {length}): {weight_per_sample:.10f} (dataset weight is 0)"
                )
            else:
                # Standard calculation: dataset_weight / num_samples_in_dataset
                weight_per_sample = current_dataset_weight / length
                logging.info(
                    f" Weight for each sample in {dataset_name} (size {length}): {weight_per_sample:.10f}"
                )

            all_sample_weights.extend([weight_per_sample] * length)

        if not all_sample_weights:  # All datasets were empty or had 0 weight
            if len(self.concatenated_dataset) > 0:  # Should not happen if logic is correct
                raise RuntimeError(
                    "Mismatch: concatenated_dataset has samples but all_sample_weights is empty."
                )
            logging.warning(
                "Warning: All datasets are effectively empty or have zero weight. Sample weights list is empty."
            )
            return None  # No samples to weight

        return torch.DoubleTensor(all_sample_weights)
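
    # A worked example of the formula above (hypothetical numbers): with
    # dataset_weights=[0.7, 0.3] and dataset lengths [1000, 10], each sample of
    # the first dataset gets weight 0.7 / 1000 = 0.0007 and each sample of the
    # second gets 0.3 / 10 = 0.03, i.e. items from the smaller dataset carry a
    # larger per-sample weight.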

    def get_dataloader(self) -> DataLoader:
        """Create and return a PyTorch DataLoader with weighted sampling.

        Uses HierarchicalSampler to first sample a dataset according to weights,
        then uniformly sample within that dataset.

        Returns:
            DataLoader configured for weighted hierarchical sampling.

        Raises:
            ValueError: If no non-empty dataset has a positive sampling weight.
        """
        if len(self.concatenated_dataset) == 0:
            logging.warning("Warning: Concatenated dataset is empty. DataLoader will produce no batches.")
            # Return an empty dataloader or raise error, depending on desired behavior.
            # For now, let it create an empty dataloader.
            return DataLoader(
                self.concatenated_dataset, batch_size=self.cfg.batch_size, num_workers=self.cfg.num_workers
            )

        # Validate there is at least one non-empty dataset with positive weight
        if not any(len(ds) > 0 and w > 0 for ds, w in zip(self.datasets, self.dataset_weights, strict=True)):
            logging.error("Error: No non-empty dataset has a positive sampling weight.")
            raise ValueError("No non-empty dataset has a positive sampling weight.")

        num_samples_per_epoch = len(self.concatenated_dataset)
        logging.info("\nCreating DataLoader...")
        logging.info(f" Batch size: {self.cfg.batch_size}")
        logging.info(f" Samples per epoch (num_samples for sampler): {num_samples_per_epoch}")

        # Hierarchical sampling: choose dataset by weight, then uniform within it (both with replacement)
        ds_lengths = [len(ds) for ds in self.datasets]
        sampler = HierarchicalSampler(
            dataset_lengths=ds_lengths,
            dataset_probs=self.dataset_weights,
            num_samples=num_samples_per_epoch,
        )

        dataloader = DataLoader(
            self.concatenated_dataset,
            batch_size=self.cfg.dataloader_batch_size,
            sampler=sampler,
            num_workers=self.cfg.num_workers,
            pin_memory=torch.cuda.is_available(),
            drop_last=False,
            prefetch_factor=self.cfg.prefetch_factor,
        )
        logging.info("DataLoader created successfully.")
        logging.info("-" * 30)
        return dataloader
opentau/datasets/factory.py
@@ -0,0 +1,232 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory functions for creating datasets and dataset mixtures.

This module provides factory functions to create individual datasets and
weighted dataset mixtures from configuration objects. It handles the setup
of delta timestamps, image transforms, and metadata configuration before
instantiating datasets.

The factory supports two types of datasets:
1. LeRobot datasets: Standard robot learning datasets loaded from HuggingFace
   repositories with configurable delta timestamps for temporal alignment.
2. Grounding datasets: Vision-language grounding datasets (CLEVR, COCO-QA,
   PIXMO, VSR, etc.) for multimodal learning tasks.

Key Features:
    - Delta timestamp resolution: Automatically configures temporal offsets
      for features based on policy latency settings (action decoder and
      cloud VLM latencies).
    - Image transform support: Applies configurable image transformations
      during dataset creation.
    - Imagenet stats override: Optionally replaces dataset statistics with
      ImageNet normalization statistics for camera features.
    - Grounding dataset registration: Supports extensible grounding dataset
      registration through side-effect imports.

Functions:
    make_dataset: Creates a single dataset instance from a DatasetConfig,
        handling delta timestamp setup, image transforms, and metadata
        configuration.
    make_dataset_mixture: Creates a WeightedDatasetMixture from a
        TrainPipelineConfig containing multiple dataset configurations.
    resolve_delta_timestamps: Resolves delta timestamps configuration based
        on TrainPipelineConfig settings, mapping features to temporal groups.

Constants:
    IMAGENET_STATS: ImageNet normalization statistics (mean, std, min, max)
        used for camera feature normalization when use_imagenet_stats is enabled.

Example:
    Create a single dataset:
    >>> dataset = make_dataset(dataset_cfg, train_cfg, return_advantage_input=False)

    Create a dataset mixture:
    >>> mixture = make_dataset_mixture(train_cfg, return_advantage_input=False)
    >>> dataloader = mixture.get_dataloader()
"""

import numpy as np

# NOTE: Don't delete; imported for side effects.
import opentau.datasets.grounding.clevr  # noqa: F401
import opentau.datasets.grounding.cocoqa  # noqa: F401
import opentau.datasets.grounding.dummy  # noqa: F401
import opentau.datasets.grounding.pixmo  # noqa: F401
import opentau.datasets.grounding.vsr  # noqa: F401
from opentau import available_grounding_datasets
from opentau.configs.default import DatasetConfig
from opentau.configs.train import TrainPipelineConfig
from opentau.datasets.dataset_mixture import WeightedDatasetMixture
from opentau.datasets.lerobot_dataset import (
    BaseDataset,
    LeRobotDataset,
    LeRobotDatasetMetadata,
)
from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING
from opentau.datasets.transforms import ImageTransforms

IMAGENET_STATS = {
    "min": [[[0.0]], [[0.0]], [[0.0]]],  # (c,1,1)
    "max": [[[1.0]], [[1.0]], [[1.0]]],  # (c,1,1)
    "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
    "std": [[[0.229]], [[0.224]], [[0.225]]],  # (c,1,1)
}


def resolve_delta_timestamps(
    cfg: TrainPipelineConfig, dataset_cfg: DatasetConfig, ds_meta: LeRobotDatasetMetadata
) -> tuple:
    """Resolves delta_timestamps based on the TrainPipelineConfig.

    Args:
        cfg (TrainPipelineConfig): The TrainPipelineConfig to read delta_indices from.
        dataset_cfg (DatasetConfig): The dataset config whose repo_id is used to look up the
            feature name mapping.
        ds_meta (LeRobotDatasetMetadata): The dataset whose features and fps are used to build
            the delta_timestamps.

    Returns:
        A 2-tuple containing:

        - At index 0, a 4-tuple containing delta timestamps mean, std, lower, and upper bounds for each group.
        - At index 1, a dictionary mapping feature names to their corresponding group and index.

        The delta timestamps and group mapping should follow the structure expected by LeRobotDataset.
    """
    group = "input_group"
    feature2group = {}
    # Delta timestamps are in seconds, and negative because they represent past timestamps.
    # Hence, lower and upper bounds correspond to -upper and -lower.
    delta_timestamps = {group: [-cfg.policy.action_decoder_latency_mean, -cfg.policy.cloud_vlm_latency_mean]}
    delta_timestamps_std = {group: [cfg.policy.action_decoder_latency_std, cfg.policy.cloud_vlm_latency_std]}
    delta_timestamps_lower = {
        group: [-cfg.policy.action_decoder_latency_upper, -cfg.policy.cloud_vlm_latency_upper]
    }
    delta_timestamps_upper = {
        group: [-cfg.policy.action_decoder_latency_lower, -cfg.policy.cloud_vlm_latency_lower]
    }
    action_freq = cfg.dataset_mixture.action_freq

    name_map = DATA_FEATURES_NAME_MAPPING[dataset_cfg.repo_id]
    reverse_name_map = {v: k for k, v in name_map.items()}
    for key in ds_meta.features:
        if key not in reverse_name_map:
            continue  # only process camera, state, and action features

        standard_key = reverse_name_map[key]
        if standard_key == "actions" and cfg.policy.action_delta_indices is not None:
            delta_timestamps[key] = [i / action_freq for i in cfg.policy.action_delta_indices]
            feature2group[key] = (key, None)
        if "camera" in standard_key:
            # Index 0 corresponds to action decoder latency and index 1 to cloud VLM latency.
            # Pick both indices. `_to_standard_data_format()` will separate the two.
            feature2group[key] = (group, [0, 1])
        elif standard_key == "state":
            # Pick index 0, which corresponds to latency of action decoder, and squeeze it to a scalar.
            feature2group[key] = (group, 0)

    return (
        delta_timestamps,
        delta_timestamps_std,
        delta_timestamps_lower,
        delta_timestamps_upper,
    ), feature2group
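
# A concrete illustration of the return structure (hypothetical config values,
# not from the original source): with action_decoder_latency_mean=0.1,
# cloud_vlm_latency_mean=0.5, action_delta_indices=[0, 1, 2], and
# action_freq=10.0, the first element of the returned tuple would contain
#     delta_timestamps = {"input_group": [-0.1, -0.5], "<actions key>": [0.0, 0.1, 0.2]}
# while feature2group maps each camera key to ("input_group", [0, 1]) and the
# state key to ("input_group", 0).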


def make_dataset(
    cfg: DatasetConfig,
    train_cfg: TrainPipelineConfig,
    return_advantage_input: bool = False,
) -> BaseDataset:
    """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.

    Args:
        cfg (DatasetConfig): A DatasetConfig used to create a LeRobotDataset.
        train_cfg (TrainPipelineConfig): A TrainPipelineConfig which contains a DatasetConfig and a PreTrainedConfig.
        return_advantage_input (bool): Whether the created dataset includes advantage inputs including "success",
            "episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.

    Raises:
        NotImplementedError: The MultiLeRobotDataset is currently deactivated.

    Returns:
        BaseDataset
    """
    image_transforms = ImageTransforms(cfg.image_transforms) if cfg.image_transforms.enable else None

    if isinstance(cfg.grounding, str) + isinstance(cfg.repo_id, str) != 1:
        raise ValueError("Exactly one of `cfg.grounding` and `cfg.repo_id` should be provided.")

    if isinstance(cfg.grounding, str):
        ds_cls = available_grounding_datasets.get(cfg.grounding)
        if ds_cls is None:
            raise ValueError(
                f"Unknown grounding dataset '{cfg.grounding}'. "
                f"Supported datasets are: {available_grounding_datasets.keys()}"
            )
        # TODO support dataset-specific arg / kwargs
        dataset = ds_cls(train_cfg)
    elif isinstance(cfg.repo_id, str):
        ds_meta = LeRobotDatasetMetadata(cfg.repo_id, root=cfg.root, revision=cfg.revision)
        (dt_mean, dt_std, dt_lower, dt_upper), f2g = resolve_delta_timestamps(train_cfg, cfg, ds_meta)
        dataset = LeRobotDataset(
            train_cfg,
            cfg.repo_id,
            root=cfg.root,
            episodes=cfg.episodes,
            delta_timestamps=dt_mean,
            delta_timestamps_std=dt_std,
            delta_timestamps_lower=dt_lower,
            delta_timestamps_upper=dt_upper,
            feature2group=f2g,
            image_transforms=image_transforms,
            revision=cfg.revision,
            video_backend=cfg.video_backend,
            image_resample_strategy=train_cfg.dataset_mixture.image_resample_strategy,
            vector_resample_strategy=train_cfg.dataset_mixture.vector_resample_strategy,
            return_advantage_input=return_advantage_input,
        )

    # TODO grounding datasets implement stats in original feature names, but camera_keys are standardized names
    if not isinstance(cfg.grounding, str) and "dummy" not in cfg.repo_id and cfg.use_imagenet_stats:
        for key in dataset.meta.camera_keys:
            for stats_type, stats in IMAGENET_STATS.items():
                if key not in dataset.meta.stats:
                    dataset.meta.stats[key] = {}
                dataset.meta.stats[key][stats_type] = np.array(stats, dtype=np.float32)

    return dataset


def make_dataset_mixture(
    cfg: TrainPipelineConfig, return_advantage_input: bool = False
) -> WeightedDatasetMixture:
    """Creates a dataset mixture from the provided TrainPipelineConfig.

    Args:
        cfg (TrainPipelineConfig): The configuration containing the datasets to mix.
        return_advantage_input (bool): Whether the datasets should return advantage inputs including "success",
            "episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.

    Returns:
        WeightedDatasetMixture: An instance of WeightedDatasetMixture containing the datasets.
    """
    datasets = [
        make_dataset(dataset_cfg, cfg, return_advantage_input=return_advantage_input)
        for dataset_cfg in cfg.dataset_mixture.datasets
    ]
    return WeightedDatasetMixture(cfg, datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq)