opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/datasets/dataset_mixture.py
@@ -0,0 +1,460 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Weighted dataset mixture for combining multiple datasets with controlled sampling.
16
+
17
+ This module provides functionality to combine multiple PyTorch datasets into a
18
+ single weighted mixture, enabling training on heterogeneous datasets with
19
+ controlled sampling proportions. It supports hierarchical sampling strategies
20
+ that efficiently handle large-scale dataset combinations while maintaining
21
+ memory efficiency.
22
+
23
+ The module implements a two-level sampling approach:
24
+ 1. Dataset-level sampling: Selects which dataset to sample from based on
25
+ specified weights.
26
+ 2. Sample-level sampling: Uniformly samples within the selected dataset.
27
+
28
+ This hierarchical approach avoids expensive multinomial sampling over millions
29
+ of individual samples by operating at the dataset level, making it scalable
30
+ for large dataset mixtures.
31
+
32
+ Key Features:
33
+ - Weighted sampling: Control relative sampling frequency of different
34
+ datasets through configurable weights.
35
+ - Memory-efficient sampling: Hierarchical sampler processes samples in
36
+ chunks to minimize memory overhead.
37
+ - Metadata aggregation: Automatically aggregates and standardizes metadata
38
+ from multiple datasets, including statistics normalization and feature
39
+ name mapping.
40
+ - Format standardization: Converts dataset-specific feature formats to a
41
+ common standard format, handling vector padding and missing cameras.
42
+
43
+ Classes:
44
+ WeightedDatasetMixture: Main class for combining multiple datasets with
45
+ weighted sampling. Creates concatenated datasets and provides DataLoader
46
+ with hierarchical sampling.
47
+ HierarchicalSampler: Custom PyTorch sampler that implements two-level
48
+ weighted sampling (dataset selection, then uniform sample selection).
49
+ DatasetMixtureMetadata: Aggregates metadata from multiple datasets,
50
+ standardizes feature names, pads vectors, and combines statistics.
51
+
52
+ Functions:
53
+ pad_vector: Pads the last dimension of a vector to a target size with zeros.
54
+
55
+ Example:
56
+ Create a dataset mixture with two datasets:
57
+ >>> datasets = [dataset1, dataset2]
58
+ >>> weights = [0.7, 0.3] # 70% from dataset1, 30% from dataset2
59
+ >>> mixture = WeightedDatasetMixture(cfg, datasets, weights, action_freq=30.0)
60
+ >>> dataloader = mixture.get_dataloader()
61
+ """
62
+
63
+ import logging
64
+ from typing import List, Optional
65
+
66
+ import numpy as np
67
+ import torch
68
+ from torch.utils.data import ConcatDataset, DataLoader, Sampler
69
+
70
+ from opentau.configs.train import TrainPipelineConfig
71
+ from opentau.datasets.compute_stats import aggregate_stats
72
+ from opentau.datasets.lerobot_dataset import BaseDataset, DatasetMetadata
73
+ from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING
74
+
75
+
76
+ def pad_vector(vector: np.ndarray, new_dim: int) -> np.ndarray:
77
+ """Pad the last dimension of a vector to a target size with zeros.
78
+
79
+ Args:
80
+ vector: Input numpy array to pad.
81
+ new_dim: Target size for the last dimension.
82
+
83
+ Returns:
84
+ Padded array with the last dimension expanded to new_dim. If the
85
+ vector already has the target dimension, returns it unchanged.
86
+ """
87
+ if vector.shape[-1] == new_dim:
88
+ return vector
89
+ shape = list(vector.shape)
90
+ current_dim = shape[-1]
91
+ shape[-1] = new_dim
92
+ new_vector = np.zeros(shape, dtype=vector.dtype)
93
+ new_vector[..., :current_dim] = vector
94
+ return new_vector
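A minimal editorial usage sketch for pad_vector (illustrative only, not part of the package source):

    import numpy as np

    v = np.arange(6, dtype=np.float32).reshape(2, 3)  # last dimension is 3
    padded = pad_vector(v, 5)                         # last dimension padded to 5 with zeros
    assert padded.shape == (2, 5)
    assert np.array_equal(padded[..., :3], v)         # original values are preserved
    assert np.all(padded[..., 3:] == 0.0)             # padded entries are zero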
95
+
96
+
97
+ class DatasetMixtureMetadata:
98
+ """A class to hold metadata for a mixture of datasets.
99
+
100
+ This is used to aggregate metadata from multiple datasets into a single object.
101
+ """
102
+
103
+ def __init__(
104
+ self, cfg: TrainPipelineConfig, metadatas: List[DatasetMetadata], dataset_weights: List[float]
105
+ ):
106
+ self.cfg = cfg
107
+
108
+ # convert each metadata stats to the standard data format
109
+ for metadata in metadatas:
110
+ metadata.stats = self._to_standard_data_format(metadata.repo_id, metadata.stats)
111
+
112
+ self.stats = aggregate_stats([metadata.stats for metadata in metadatas], weights=dataset_weights)
113
+
114
+ def _to_standard_data_format(
115
+ self, repo_id: str, stats: dict[str, dict[str, np.ndarray]]
116
+ ) -> dict[str, dict[str, np.ndarray]]:
117
+ """Convert statistics to the standard data format.
118
+
119
+ Maps feature names from dataset-specific format to standard format,
120
+ pads state and action vectors, and ensures all required cameras are present.
121
+
122
+ Args:
123
+ repo_id: Repository ID used to look up feature name mapping.
124
+ stats: Statistics dictionary with dataset-specific feature names.
125
+
126
+ Returns:
127
+ Statistics dictionary with standard feature names and padded vectors.
128
+
129
+ Raises:
130
+ KeyError: If a required feature is missing from stats or if required
131
+ statistics (mean, std, min, max) are missing.
132
+ """
133
+ name_map = DATA_FEATURES_NAME_MAPPING[repo_id]
134
+ features_without_stats = ["prompt", "response", "advantage"]
135
+
136
+ standard_stats = {}
137
+ for new_key, key in name_map.items():
138
+ if new_key in features_without_stats:
139
+ # skip features that do not have stats
140
+ continue
141
+
142
+ # ensure only the first num_cams is used
143
+ if new_key.startswith("camera"):
144
+ cam_idx = int(new_key[len("camera") :])
145
+ if cam_idx >= self.cfg.num_cams:
146
+ continue
147
+ if key in stats:
148
+ standard_stats[new_key] = stats[key]
149
+ else:
150
+ raise KeyError(f"Key '{key}' not found in stats. Available keys: {list(stats.keys())}")
151
+
152
+ # pad state and action vectors
153
+ for stat in standard_stats["state"]:
154
+ if stat in ["mean", "std", "min", "max"]:
155
+ standard_stats["state"][stat] = pad_vector(
156
+ standard_stats["state"][stat], self.cfg.max_state_dim
157
+ )
158
+ standard_stats["actions"][stat] = pad_vector(
159
+ standard_stats["actions"][stat], self.cfg.max_action_dim
160
+ )
161
+
162
+ # pad missing cameras
163
+ for cam_idx in range(self.cfg.num_cams):
164
+ if f"camera{cam_idx}" in standard_stats:
165
+ continue
166
+ standard_stats[f"camera{cam_idx}"] = {
167
+ "min": np.zeros((3, 1, 1), dtype=np.float32),
168
+ "max": np.ones((3, 1, 1), dtype=np.float32),
169
+ "mean": np.zeros((3, 1, 1), dtype=np.float32),
170
+ "std": np.zeros((3, 1, 1), dtype=np.float32),
171
+ "count": np.array(
172
+ standard_stats["state"]["count"]
173
+ ), # create a copy in case this gets modified
174
+ }
175
+
176
+ # check for missing keys
177
+ for data in standard_stats:
178
+ missing_keys = {"mean", "std", "min", "max"} - standard_stats[data].keys()
179
+ if missing_keys:
180
+ raise KeyError(
181
+ f"The dataset {repo_id} is missing required statistics: {', '.join(sorted(missing_keys))}"
182
+ )
183
+
184
+ return standard_stats
185
+
186
+ @property
187
+ def features(self) -> dict[str, dict]:
188
+ """Return standard data format"""
189
+ features = {
190
+ "state": {
191
+ "shape": (self.cfg.max_state_dim,),
192
+ "dtype": "float32",
193
+ },
194
+ "actions": {
195
+ "shape": (self.cfg.max_action_dim,),
196
+ "dtype": "float32",
197
+ },
198
+ }
199
+ # add camera features
200
+ for i in range(self.cfg.num_cams):
201
+ features[f"camera{i}"] = {
202
+ "shape": (3, self.cfg.resolution[0], self.cfg.resolution[1]),
203
+ "dtype": "image",
204
+ }
205
+ return features
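For orientation, an editorial sketch of what the features property returns for a hypothetical config with max_state_dim=32, max_action_dim=32, num_cams=2, and resolution=(224, 224):

    # Illustrative return value of DatasetMixtureMetadata.features under the assumed config above.
    {
        "state": {"shape": (32,), "dtype": "float32"},
        "actions": {"shape": (32,), "dtype": "float32"},
        "camera0": {"shape": (3, 224, 224), "dtype": "image"},
        "camera1": {"shape": (3, 224, 224), "dtype": "image"},
    }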
206
+
207
+
208
+ class HierarchicalSampler(Sampler[int]):
209
+ r"""With-replacement sampler for a ConcatDataset that first samples a dataset according to `dataset_probs`, and then
210
+ samples uniformly within that dataset. This avoids a single multinomial draw over a huge number of per-sample categories (potentially over 2^24)
211
+ by operating at the dataset level.
212
+ """
213
+
214
+ def __init__(
215
+ self,
216
+ dataset_lengths: List[int],
217
+ dataset_probs: List[float],
218
+ num_samples: int,
219
+ *,
220
+ generator: Optional[torch.Generator] = None,
221
+ seed: Optional[int] = None,
222
+ chunk_size: int = 262144,
223
+ ):
224
+ super().__init__()
225
+
226
+ if len(dataset_lengths) != len(dataset_probs):
227
+ raise ValueError("dataset_lengths and dataset_probs must have the same length.")
228
+ self.num_samples = int(num_samples)
229
+ self.chunk_size = int(chunk_size)
230
+
231
+ lens = torch.as_tensor(dataset_lengths, dtype=torch.long)
232
+ probs = torch.as_tensor(dataset_probs, dtype=torch.double)
233
+
234
+ if (lens < 0).any():
235
+ raise ValueError("dataset_lengths must be non-negative.")
236
+
237
+ # Offsets for mapping local indices to global ConcatDataset indices
238
+ self._full_offsets = torch.zeros(len(lens), dtype=torch.long)
239
+ if len(lens) > 0:
240
+ self._full_offsets[1:] = lens.cumsum(0)[:-1]
241
+
242
+ # Keep only non-empty datasets with positive probability
243
+ valid_mask = (lens > 0) & (probs > 0)
244
+ if not bool(valid_mask.any()):
245
+ raise ValueError("All datasets are empty or have zero probability.")
246
+
247
+ self._valid_ids = torch.nonzero(valid_mask, as_tuple=False).flatten()
248
+ self._valid_lens = lens[self._valid_ids]
249
+ valid_probs = probs[self._valid_ids]
250
+ self._valid_probs = (valid_probs / valid_probs.sum()).to(dtype=torch.double)
251
+
252
+ self._num_valid = int(self._valid_ids.numel())
253
+ self._gen = generator if generator is not None else torch.Generator()
254
+ if seed is not None:
255
+ self._gen.manual_seed(int(seed))
256
+
257
+ def __len__(self) -> int:
258
+ return self.num_samples
259
+
260
+ def __iter__(self):
261
+ # Generate indices in memory-friendly chunks
262
+ total = self.num_samples
263
+ cs = self.chunk_size
264
+ for start in range(0, total, cs):
265
+ m = min(cs, total - start)
266
+
267
+ # Choose dataset ids according to probs (over valid ids only)
268
+ ds_choices_valid = torch.multinomial(self._valid_probs, m, replacement=True, generator=self._gen)
269
+
270
+ # For each chosen dataset, draw uniform local indices and map to global indices
271
+ out = torch.empty(m, dtype=torch.long)
272
+ for k in range(self._num_valid):
273
+ mask = ds_choices_valid == k
274
+ k_count = int(mask.sum().item())
275
+ if k_count == 0:
276
+ continue
277
+ local_idx = torch.randint(0, int(self._valid_lens[k].item()), (k_count,), generator=self._gen)
278
+ orig_ds_id = int(self._valid_ids[k].item())
279
+ out[mask] = local_idx + self._full_offsets[orig_ds_id]
280
+
281
+ # Yield one by one to conform to Sampler API
282
+ for idx in out.tolist():
283
+ yield int(idx)
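A self-contained editorial sketch of pairing HierarchicalSampler with a ConcatDataset and DataLoader (the toy datasets and sizes are made up):

    import torch
    from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

    ds_a = TensorDataset(torch.zeros(1000, 4))  # 1000 samples
    ds_b = TensorDataset(torch.ones(100, 4))    # 100 samples
    concat = ConcatDataset([ds_a, ds_b])

    # Roughly 70% of draws come from ds_a and 30% from ds_b, with replacement.
    sampler = HierarchicalSampler(
        dataset_lengths=[len(ds_a), len(ds_b)],
        dataset_probs=[0.7, 0.3],
        num_samples=len(concat),
        seed=0,
    )
    loader = DataLoader(concat, batch_size=32, sampler=sampler)
    first_batch = next(iter(loader))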
284
+
285
+
286
+ class WeightedDatasetMixture:
287
+ """
288
+ A class to combine multiple PyTorch Datasets and create a DataLoader
289
+ that samples from them according to specified weightings.
290
+ """
291
+
292
+ def __init__(
293
+ self,
294
+ cfg: TrainPipelineConfig,
295
+ datasets: List[BaseDataset],
296
+ dataset_weights: List[float],
297
+ action_freq: float,
298
+ ):
299
+ """
300
+ Initializes the WeightedDatasetMixture.
301
+
302
+ Args:
303
+ cfg (TrainPipelineConfig): Configuration for the training pipeline.
304
+ datasets (List[Dataset]): A list of PyTorch Dataset objects.
305
+ dataset_weights (List[float]): A list of weights corresponding to each dataset.
306
+ These determine the relative sampling frequency.
+ action_freq (float): Frequency in Hz used for resampling the action outputs.
307
+ """
308
+ if not datasets:
309
+ raise ValueError("The list of datasets cannot be empty.")
310
+ if len(datasets) != len(dataset_weights):
311
+ raise ValueError("The number of datasets must match the number of dataset_weights.")
312
+ if any(w < 0 for w in dataset_weights):
313
+ raise ValueError("Dataset weights must be non-negative.")
314
+ if sum(dataset_weights) == 0 and any(len(ds) > 0 for ds in datasets):
315
+ # If all weights are zero, but there's data, sampler will fail.
316
+ # If all datasets are empty, sum of weights being zero is fine.
317
+ logging.warning(
318
+ "Warning: All dataset weights are zero. The sampler might not behave as expected if datasets have samples."
319
+ )
320
+
321
+ self.cfg = cfg
322
+ self.datasets = datasets
323
+ self.dataset_weights = dataset_weights
324
+ self.action_freq = action_freq # Frequency used for resampling action output
325
+ self.dataset_names = [type(ds).__name__ + f"_{i}" for i, ds in enumerate(datasets)] # For logging
326
+
327
+ logging.info("Initializing WeightedDatasetMixture...")
328
+ self._log_dataset_info()
329
+
330
+ self.concatenated_dataset: ConcatDataset = ConcatDataset(datasets)
331
+ logging.info(f"Total length of concatenated dataset: {len(self.concatenated_dataset)}")
332
+
333
+ self.sample_weights: torch.Tensor = self._calculate_sample_weights()
334
+ if self.sample_weights is None and len(self.concatenated_dataset) > 0:
335
+ raise ValueError("Sample weights could not be calculated, but concatenated dataset is not empty.")
336
+ elif self.sample_weights is not None and len(self.sample_weights) != len(self.concatenated_dataset):
337
+ raise ValueError(
338
+ f"Length of sample_weights ({len(self.sample_weights)}) "
339
+ f"must match concatenated_dataset length ({len(self.concatenated_dataset)})."
340
+ )
341
+ logging.info("-" * 30)
342
+
343
+ # aggregate metadata
344
+ if not all(hasattr(ds, "meta") and ds.meta is not None for ds in datasets):
345
+ raise ValueError("All datasets must have a 'meta' attribute with valid metadata.")
346
+ self.meta = DatasetMixtureMetadata(cfg, [ds.meta for ds in datasets], dataset_weights)
347
+
348
+ def _log_dataset_info(self) -> None:
349
+ """Log information about all datasets in the mixture."""
350
+ logging.info("Dataset information:")
351
+ for i, ds in enumerate(self.datasets):
352
+ logging.info(f" - {self.dataset_names[i]}: Length={len(ds)}, Weight={self.dataset_weights[i]}")
353
+ logging.info("-" * 30)
354
+
355
+ def _calculate_sample_weights(self) -> Optional[torch.Tensor]:
356
+ """Calculate the weight for each individual sample in the concatenated dataset.
357
+
358
+ Samples from datasets with higher weights or smaller sizes (for a given weight)
359
+ will have higher individual sample weights. Weight per sample = dataset_weight / dataset_length.
360
+
361
+ Returns:
362
+ Tensor of sample weights, or None if all datasets are empty or have zero weight.
363
+
364
+ Raises:
365
+ RuntimeError: If there's a mismatch between concatenated dataset length
366
+ and calculated sample weights.
367
+ """
368
+ if not self.concatenated_dataset: # Handles case where all input datasets are empty
369
+ logging.warning("Warning: Concatenated dataset is empty. No sample weights to calculate.")
370
+ return None
371
+
372
+ logging.info("Calculating per-sample weights...")
373
+ all_sample_weights: List[float] = []
374
+ dataset_lengths = [len(ds) for ds in self.datasets]
375
+
376
+ for i, length in enumerate(dataset_lengths):
377
+ dataset_name = self.dataset_names[i]
378
+ current_dataset_weight = self.dataset_weights[i]
379
+
380
+ if length == 0:
381
+ logging.info(f" Skipping {dataset_name} (length 0).")
382
+ continue # Skip empty datasets
383
+
384
+ if current_dataset_weight == 0:
385
+ # Assign zero weight to all samples in this dataset
386
+ weight_per_sample = 0.0
387
+ logging.info(
388
+ f" Weight for each sample in {dataset_name} (size {length}): {weight_per_sample:.10f} (dataset weight is 0)"
389
+ )
390
+ else:
391
+ # Standard calculation: dataset_weight / num_samples_in_dataset
392
+ weight_per_sample = current_dataset_weight / length
393
+ logging.info(
394
+ f" Weight for each sample in {dataset_name} (size {length}): {weight_per_sample:.10f}"
395
+ )
396
+
397
+ all_sample_weights.extend([weight_per_sample] * length)
398
+
399
+ if not all_sample_weights: # All datasets were empty or had 0 weight
400
+ if len(self.concatenated_dataset) > 0: # Should not happen if logic is correct
401
+ raise RuntimeError(
402
+ "Mismatch: concatenated_dataset has samples but all_sample_weights is empty."
403
+ )
404
+ logging.warning(
405
+ "Warning: All datasets are effectively empty or have zero weight. Sample weights list is empty."
406
+ )
407
+ return None # No samples to weight
408
+
409
+ return torch.DoubleTensor(all_sample_weights)
410
+
411
+ def get_dataloader(self) -> DataLoader:
412
+ """Create and return a PyTorch DataLoader with weighted sampling.
413
+
414
+ Uses HierarchicalSampler to first sample a dataset according to weights,
415
+ then uniformly sample within that dataset.
416
+
417
+ Returns:
418
+ DataLoader configured for weighted hierarchical sampling.
419
+
420
+ Raises:
421
+ ValueError: If no non-empty dataset has a positive sampling weight.
422
+ """
423
+ if len(self.concatenated_dataset) == 0:
424
+ logging.warning("Warning: Concatenated dataset is empty. DataLoader will produce no batches.")
425
+ # Return an empty dataloader or raise error, depending on desired behavior.
426
+ # For now, let it create an empty dataloader.
427
+ return DataLoader(
428
+ self.concatenated_dataset, batch_size=self.cfg.batch_size, num_workers=self.cfg.num_workers
429
+ )
430
+
431
+ # Validate there is at least one non-empty dataset with positive weight
432
+ if not any(len(ds) > 0 and w > 0 for ds, w in zip(self.datasets, self.dataset_weights, strict=True)):
433
+ logging.error("Error: No non-empty dataset has a positive sampling weight.")
434
+ raise ValueError("No non-empty dataset has a positive sampling weight.")
435
+
436
+ num_samples_per_epoch = len(self.concatenated_dataset)
437
+ logging.info("\nCreating DataLoader...")
438
+ logging.info(f" Batch size: {self.cfg.batch_size}")
439
+ logging.info(f" Samples per epoch (num_samples for sampler): {num_samples_per_epoch}")
440
+
441
+ # Hierarchical sampling: choose dataset by weight, then uniform within it (both with replacement)
442
+ ds_lengths = [len(ds) for ds in self.datasets]
443
+ sampler = HierarchicalSampler(
444
+ dataset_lengths=ds_lengths,
445
+ dataset_probs=self.dataset_weights,
446
+ num_samples=num_samples_per_epoch,
447
+ )
448
+
449
+ dataloader = DataLoader(
450
+ self.concatenated_dataset,
451
+ batch_size=self.cfg.dataloader_batch_size,
452
+ sampler=sampler,
453
+ num_workers=self.cfg.num_workers,
454
+ pin_memory=torch.cuda.is_available(),
455
+ drop_last=False,
456
+ prefetch_factor=self.cfg.prefetch_factor,
457
+ )
458
+ logging.info("DataLoader created successfully.")
459
+ logging.info("-" * 30)
460
+ return dataloader
opentau/datasets/factory.py
@@ -0,0 +1,232 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Factory functions for creating datasets and dataset mixtures.
19
+
20
+ This module provides factory functions to create individual datasets and
21
+ weighted dataset mixtures from configuration objects. It handles the setup
22
+ of delta timestamps, image transforms, and metadata configuration before
23
+ instantiating datasets.
24
+
25
+ The factory supports two types of datasets:
26
+ 1. LeRobot datasets: Standard robot learning datasets loaded from HuggingFace
27
+ repositories with configurable delta timestamps for temporal alignment.
28
+ 2. Grounding datasets: Vision-language grounding datasets (CLEVR, COCO-QA,
29
+ PIXMO, VSR, etc.) for multimodal learning tasks.
30
+
31
+ Key Features:
32
+ - Delta timestamp resolution: Automatically configures temporal offsets
33
+ for features based on policy latency settings (action decoder and
34
+ cloud VLM latencies).
35
+ - Image transform support: Applies configurable image transformations
36
+ during dataset creation.
37
+ - Imagenet stats override: Optionally replaces dataset statistics with
38
+ ImageNet normalization statistics for camera features.
39
+ - Grounding dataset registration: Supports extensible grounding dataset
40
+ registration through side-effect imports.
41
+
42
+ Functions:
43
+ make_dataset: Creates a single dataset instance from a DatasetConfig,
44
+ handling delta timestamp setup, image transforms, and metadata
45
+ configuration.
46
+ make_dataset_mixture: Creates a WeightedDatasetMixture from a
47
+ TrainPipelineConfig containing multiple dataset configurations.
48
+ resolve_delta_timestamps: Resolves delta timestamps configuration based
49
+ on TrainPipelineConfig settings, mapping features to temporal groups.
50
+
51
+ Constants:
52
+ IMAGENET_STATS: ImageNet normalization statistics (mean, std, min, max)
53
+ used for camera feature normalization when use_imagenet_stats is enabled.
54
+
55
+ Example:
56
+ Create a single dataset:
57
+ >>> dataset = make_dataset(dataset_cfg, train_cfg, return_advantage_input=False)
58
+
59
+ Create a dataset mixture:
60
+ >>> mixture = make_dataset_mixture(train_cfg, return_advantage_input=False)
61
+ >>> dataloader = mixture.get_dataloader()
62
+ """
63
+
64
+ import numpy as np
65
+
66
+ # NOTE: Don't delete; imported for side effects.
67
+ import opentau.datasets.grounding.clevr # noqa: F401
68
+ import opentau.datasets.grounding.cocoqa # noqa: F401
69
+ import opentau.datasets.grounding.dummy # noqa: F401
70
+ import opentau.datasets.grounding.pixmo # noqa: F401
71
+ import opentau.datasets.grounding.vsr # noqa: F401
72
+ from opentau import available_grounding_datasets
73
+ from opentau.configs.default import DatasetConfig
74
+ from opentau.configs.train import TrainPipelineConfig
75
+ from opentau.datasets.dataset_mixture import WeightedDatasetMixture
76
+ from opentau.datasets.lerobot_dataset import (
77
+ BaseDataset,
78
+ LeRobotDataset,
79
+ LeRobotDatasetMetadata,
80
+ )
81
+ from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING
82
+ from opentau.datasets.transforms import ImageTransforms
83
+
84
+ IMAGENET_STATS = {
85
+ "min": [[[0.0]], [[0.0]], [[0.0]]], # (c,1,1)
86
+ "max": [[[1.0]], [[1.0]], [[1.0]]], # (c,1,1)
87
+ "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
88
+ "std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
89
+ }
90
+
91
+
92
+ def resolve_delta_timestamps(
93
+ cfg: TrainPipelineConfig, dataset_cfg: DatasetConfig, ds_meta: LeRobotDatasetMetadata
94
+ ) -> tuple:
95
+ """Resolves delta_timestamps by based on TrainPipelineConfig.
96
+
97
+ Args:
98
+ cfg (TrainPipelineConfig): The TrainPipelineConfig to read the policy latency settings and action delta indices from.
+ dataset_cfg (DatasetConfig): The dataset config whose repo_id selects the feature name mapping.
99
+ ds_meta (LeRobotDatasetMetadata): The dataset metadata whose features are used to build the
+ delta_timestamps and the feature-to-group mapping.
101
+
102
+ Returns:
103
+ A 2-tuple containing:
104
+
105
+ - At index 0, a 4-tuple containing delta timestamps mean, std, lower, and upper bounds for each group.
106
+ - At index 1, a dictionary mapping feature names to their corresponding group and index.
107
+
108
+ The delta timestamps and group mapping should follow the structure expected by LeRobotDataset.
109
+ """
110
+ group = "input_group"
111
+ feature2group = {}
112
+ # Delta timestamps are in seconds, and negative because they represent past timestamps.
113
+ # Hence, lower and upper bounds correspond to -upper and -lower.
114
+ delta_timestamps = {group: [-cfg.policy.action_decoder_latency_mean, -cfg.policy.cloud_vlm_latency_mean]}
115
+ delta_timestamps_std = {group: [cfg.policy.action_decoder_latency_std, cfg.policy.cloud_vlm_latency_std]}
116
+ delta_timestamps_lower = {
117
+ group: [-cfg.policy.action_decoder_latency_upper, -cfg.policy.cloud_vlm_latency_upper]
118
+ }
119
+ delta_timestamps_upper = {
120
+ group: [-cfg.policy.action_decoder_latency_lower, -cfg.policy.cloud_vlm_latency_lower]
121
+ }
122
+ action_freq = cfg.dataset_mixture.action_freq
123
+
124
+ name_map = DATA_FEATURES_NAME_MAPPING[dataset_cfg.repo_id]
125
+ reverse_name_map = {v: k for k, v in name_map.items()}
126
+ for key in ds_meta.features:
127
+ if key not in reverse_name_map:
128
+ continue # only process camera, state, and action features
129
+
130
+ standard_key = reverse_name_map[key]
131
+ if standard_key == "actions" and cfg.policy.action_delta_indices is not None:
132
+ delta_timestamps[key] = [i / action_freq for i in cfg.policy.action_delta_indices]
133
+ feature2group[key] = (key, None)
134
+ if "camera" in standard_key:
135
+ # Index 0 corresponds to action decoder latency and index 1 to cloud VLM latency.
136
+ # Pick both indices. `_to_standard_data_format()` will separate the two.
137
+ feature2group[key] = (group, [0, 1])
138
+ elif standard_key == "state":
139
+ # Pick index 0, which corresponds to latency of action decoder, and squeeze it to a scalar.
140
+ feature2group[key] = (group, 0)
141
+
142
+ return (
143
+ delta_timestamps,
144
+ delta_timestamps_std,
145
+ delta_timestamps_lower,
146
+ delta_timestamps_upper,
147
+ ), feature2group
148
+
149
+
150
+ def make_dataset(
151
+ cfg: DatasetConfig,
152
+ train_cfg: TrainPipelineConfig,
153
+ return_advantage_input: bool = False,
154
+ ) -> BaseDataset:
155
+ """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
156
+
157
+ Args:
158
+ cfg (DatasetConfig): A DatasetConfig used to create a LeRobotDataset.
159
+ train_cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
160
+ return_advantage_input (bool): Whether the created dataset returns advantage-related inputs such as "success",
161
+ "episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.
162
+
163
+ Raises:
164
+ ValueError: If neither or both of `cfg.grounding` and `cfg.repo_id` are provided, or if the
+ requested grounding dataset is not registered.
165
+
166
+ Returns:
167
+ BaseDataset
168
+ """
169
+ image_transforms = ImageTransforms(cfg.image_transforms) if cfg.image_transforms.enable else None
170
+
171
+ if isinstance(cfg.grounding, str) + isinstance(cfg.repo_id, str) != 1:
172
+ raise ValueError("Exactly one of `cfg.grounding` and `cfg.repo_id` should be provided.")
173
+
174
+ if isinstance(cfg.grounding, str):
175
+ ds_cls = available_grounding_datasets.get(cfg.grounding)
176
+ if ds_cls is None:
177
+ raise ValueError(
178
+ f"Unknown grounding dataset '{cfg.grounding}'. "
179
+ f"Supported datasets are: {available_grounding_datasets.keys()}"
180
+ )
181
+ # TODO support dataset-specific arg / kwargs
182
+ dataset = ds_cls(train_cfg)
183
+ elif isinstance(cfg.repo_id, str):
184
+ ds_meta = LeRobotDatasetMetadata(cfg.repo_id, root=cfg.root, revision=cfg.revision)
185
+ (dt_mean, dt_std, dt_lower, dt_upper), f2g = resolve_delta_timestamps(train_cfg, cfg, ds_meta)
186
+ dataset = LeRobotDataset(
187
+ train_cfg,
188
+ cfg.repo_id,
189
+ root=cfg.root,
190
+ episodes=cfg.episodes,
191
+ delta_timestamps=dt_mean,
192
+ delta_timestamps_std=dt_std,
193
+ delta_timestamps_lower=dt_lower,
194
+ delta_timestamps_upper=dt_upper,
195
+ feature2group=f2g,
196
+ image_transforms=image_transforms,
197
+ revision=cfg.revision,
198
+ video_backend=cfg.video_backend,
199
+ image_resample_strategy=train_cfg.dataset_mixture.image_resample_strategy,
200
+ vector_resample_strategy=train_cfg.dataset_mixture.vector_resample_strategy,
201
+ return_advantage_input=return_advantage_input,
202
+ )
203
+
204
+ # TODO grounding datasets implement stats in original feature names, but camera_keys are standardized names
205
+ if not isinstance(cfg.grounding, str) and "dummy" not in cfg.repo_id and cfg.use_imagenet_stats:
206
+ for key in dataset.meta.camera_keys:
207
+ for stats_type, stats in IMAGENET_STATS.items():
208
+ if key not in dataset.meta.stats:
209
+ dataset.meta.stats[key] = {}
210
+ dataset.meta.stats[key][stats_type] = np.array(stats, dtype=np.float32)
211
+
212
+ return dataset
213
+
214
+
215
+ def make_dataset_mixture(
216
+ cfg: TrainPipelineConfig, return_advantage_input: bool = False
217
+ ) -> WeightedDatasetMixture:
218
+ """Creates a dataset mixture from the provided TrainPipelineConfig.
219
+
220
+ Args:
221
+ cfg (TrainPipelineConfig): The configuration containing the datasets to mix.
222
+ return_advantage_input (bool): Whether the datasets should return advantage-related inputs such as "success",
223
+ "episode_end_idx", "current_idx", "last_step", "episode_index", and "timestamp". Defaults to False.
224
+
225
+ Returns:
226
+ WeightedDatasetMixture: An instance of WeightedDatasetMixture containing the datasets.
227
+ """
228
+ datasets = [
229
+ make_dataset(dataset_cfg, cfg, return_advantage_input=return_advantage_input)
230
+ for dataset_cfg in cfg.dataset_mixture.datasets
231
+ ]
232
+ return WeightedDatasetMixture(cfg, datasets, cfg.dataset_mixture.weights, cfg.dataset_mixture.action_freq)
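A hedged end-to-end sketch of the factory entry point; building the TrainPipelineConfig is elided because its exact fields are defined in opentau/configs/train.py, and load_or_build_train_config below is a hypothetical placeholder:

    from opentau.configs.train import TrainPipelineConfig
    from opentau.datasets.factory import make_dataset_mixture

    # Hypothetical helper standing in for however the config is parsed or constructed.
    cfg: TrainPipelineConfig = load_or_build_train_config()

    mixture = make_dataset_mixture(cfg, return_advantage_input=False)
    dataloader = mixture.get_dataloader()
    batch = next(iter(dataloader))  # samples drawn via weighted hierarchical sampling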