opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,442 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """An online buffer for the online training loop in train.py
18
+
19
+ Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
20
+ consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
21
+ faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
22
+ supports in-place slicing and mutation which is very handy for a dynamic buffer.
23
+ """
24
+
25
+ import os
26
+ from pathlib import Path
27
+ from typing import Any
28
+
29
+ import numpy as np
30
+ import torch
31
+
32
+ from opentau.datasets.lerobot_dataset import LeRobotDataset
33
+
34
+
35
+ def _make_memmap_safe(**kwargs) -> np.memmap:
36
+ """Create a numpy memmap with checks on available disk space.
37
+
38
+ Validates that sufficient disk space is available before creating the memmap
39
+ file in write mode.
40
+
41
+ Args:
42
+ **kwargs: Keyword arguments for np.memmap including:
43
+ - filename: Path to the memmap file.
44
+ - dtype: numpy dtype (must be np.dtype).
45
+ - mode: File mode ('r+', 'w+', etc.).
46
+ - shape: Shape of the array.
47
+
48
+ Returns:
49
+ numpy memmap array.
50
+
51
+ Raises:
52
+ RuntimeError: If required disk space exceeds 80% of available space.
53
+
54
+ Note:
55
+ For information on dtypes, see:
56
+ https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
57
+ """
58
+ if kwargs["mode"].startswith("w"):
59
+ required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
60
+ stats = os.statvfs(Path(kwargs["filename"]).parent)
61
+ available_space = stats.f_bavail * stats.f_frsize # bytes
62
+ if required_space >= available_space * 0.8:
63
+ raise RuntimeError(
64
+ f"You're about to take up {required_space} of {available_space} bytes available."
65
+ )
66
+ return np.memmap(**kwargs)
67
+
68
+
69
+ class OnlineBuffer(torch.utils.data.Dataset):
70
+ """FIFO data buffer for the online training loop in train.py.
71
+
72
+ Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
73
+ loop in the same way that a LeRobotDataset would be used.
74
+
75
+ The underlying data structure will have data inserted in a circular fashion. Always insert after the
76
+ last index, and when you reach the end, wrap around to the start.
77
+
78
+ The data is stored in a numpy memmap.
79
+ """
80
+
81
+ NEXT_INDEX_KEY = "_next_index"
82
+ OCCUPANCY_MASK_KEY = "_occupancy_mask"
83
+ INDEX_KEY = "index"
84
+ FRAME_INDEX_KEY = "frame_index"
85
+ EPISODE_INDEX_KEY = "episode_index"
86
+ TIMESTAMP_KEY = "timestamp"
87
+ IS_PAD_POSTFIX = "_is_pad"
88
+
89
+ def __init__(
90
+ self,
91
+ write_dir: str | Path,
92
+ data_spec: dict[str, Any] | None,
93
+ buffer_capacity: int | None,
94
+ fps: float | None = None,
95
+ delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
96
+ ):
97
+ """
98
+ The online buffer can be created from scratch, or you can load an existing online buffer by passing
99
+ a `write_dir` associated with an existing buffer.
100
+
101
+ Args:
102
+ write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
103
+ Note that if the files already exist, they are opened in read-write mode (used for training
104
+ resumption).
105
+ data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
106
+ "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
107
+ but note that "index", "frame_index" and "episode_index" are already accounted for by this
108
+ class, so you don't need to include them.
109
+ buffer_capacity: The maximum number of frames to store in the buffer. Be aware of your
110
+ system's available disk space when choosing this.
111
+ fps: Same as the fps concept in LeRobotDataset. Here it needs to be provided for the
112
+ delta_timestamps logic. You can pass None if you are not using delta_timestamps.
113
+ delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
114
+ converted to dict[str, np.ndarray] for optimization purposes.
115
+
116
+ """
117
+ self.set_delta_timestamps(delta_timestamps)
118
+ self._fps = fps
119
+ # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
120
+ # the requested frames. It is only used when `delta_timestamps` is provided.
121
+ # minus 1e-4 to account for possible numerical error
122
+ self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
123
+ self._buffer_capacity = buffer_capacity
124
+ data_spec = self._make_data_spec(data_spec, buffer_capacity)
125
+ Path(write_dir).mkdir(parents=True, exist_ok=True)
126
+ self._data = {}
127
+ for k, v in data_spec.items():
128
+ self._data[k] = _make_memmap_safe(
129
+ filename=Path(write_dir) / k,
130
+ dtype=v["dtype"] if v is not None else None,
131
+ mode="r+" if (Path(write_dir) / k).exists() else "w+",
132
+ shape=tuple(v["shape"]) if v is not None else None,
133
+ )
134
+
135
+ @property
136
+ def delta_timestamps(self) -> dict[str, np.ndarray] | None:
137
+ return self._delta_timestamps
138
+
139
+ def set_delta_timestamps(self, value: dict[str, list[float]] | None) -> None:
140
+ """Set delta_timestamps converting the values to numpy arrays.
141
+
142
+ The conversion is for an optimization in the __getitem__. The loop is much
143
+ slower if the arrays need to be converted into numpy arrays on each access.
144
+
145
+ Args:
146
+ value: Dictionary mapping feature names to lists of delta timestamps,
147
+ or None to disable delta timestamps.
148
+ """
149
+ if value is not None:
150
+ self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
151
+ else:
152
+ self._delta_timestamps = None
153
+
154
+ def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
155
+ """Create the complete data specification for numpy memmap files.
156
+
157
+ Adds internal keys (_next_index, _occupancy_mask) and standard keys
158
+ (index, frame_index, episode_index, timestamp) to the user-provided spec.
159
+
160
+ Args:
161
+ data_spec: User-provided data specification dictionary.
162
+ buffer_capacity: Maximum number of frames the buffer can hold.
163
+
164
+ Returns:
165
+ Complete data specification including internal and standard keys.
166
+
167
+ Raises:
168
+ ValueError: If data_spec contains keys starting with '_' or contains
169
+ any of the preset keys (index, frame_index, episode_index, timestamp).
170
+ """
171
+ if any(k.startswith("_") for k in data_spec):
172
+ raise ValueError(
173
+ "data_spec keys should not start with '_'. This prefix is reserved for internal logic."
174
+ )
175
+ preset_keys = {
176
+ OnlineBuffer.INDEX_KEY,
177
+ OnlineBuffer.FRAME_INDEX_KEY,
178
+ OnlineBuffer.EPISODE_INDEX_KEY,
179
+ OnlineBuffer.TIMESTAMP_KEY,
180
+ }
181
+ if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
182
+ raise ValueError(
183
+ f"data_spec should not contain any of {preset_keys} as these are handled internally. "
184
+ f"The provided data_spec has {intersection}."
185
+ )
186
+ complete_data_spec = {
187
+ # _next_index will be a pointer to the next index that we should start filling from when we add
188
+ # more data.
189
+ OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
190
+ # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
191
+ # with real data rather than the dummy initialization.
192
+ OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
193
+ OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
194
+ OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
195
+ OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
196
+ OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
197
+ }
198
+ for k, v in data_spec.items():
199
+ complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
200
+ return complete_data_spec
201
+
202
+ def add_data(self, data: dict[str, np.ndarray]) -> None:
203
+ """Add new data to the buffer, potentially overwriting old data in a circular fashion.
204
+
205
+ The new data should contain all frames (in order) of any number of episodes.
206
+ Indices should start from 0. This method shifts incoming data indices and
207
+ episode indices to continue from the last frame in the buffer.
208
+
209
+ Args:
210
+ data: Dictionary mapping data keys to numpy arrays. All arrays must
211
+ have the same length. Must include all keys from data_keys.
212
+
213
+ Raises:
214
+ ValueError: If data is missing required keys or arrays have different lengths.
215
+
216
+ Note:
217
+ This modifies the input data in place by shifting indices.
218
+ See `rollout` and `eval_policy` functions in `eval.py` for usage examples.
219
+ """
220
+ if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
221
+ raise ValueError(f"Missing data keys: {missing_keys}")
222
+ new_data_length = len(data[self.data_keys[0]])
223
+ if not all(len(data[k]) == new_data_length for k in self.data_keys):
224
+ raise ValueError("All data items should have the same length")
225
+
226
+ next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
227
+
228
+ # Sanity check to make sure that the new data indices start from 0.
229
+ assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
230
+ assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
231
+
232
+ # Shift the incoming indices if necessary.
233
+ if self.num_frames > 0:
234
+ last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
235
+ last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
236
+ data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
237
+ data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
238
+
239
+ # Insert the new data starting from next_index. It may be necessary to wrap around to the start.
240
+ n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
241
+ for k in self.data_keys:
242
+ if n_surplus == 0:
243
+ slc = slice(next_index, next_index + new_data_length)
244
+ self._data[k][slc] = data[k]
245
+ self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
246
+ else:
247
+ self._data[k][next_index:] = data[k][:-n_surplus]
248
+ self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
249
+ self._data[k][:n_surplus] = data[k][-n_surplus:]
250
+ if n_surplus == 0:
251
+ self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
252
+ else:
253
+ self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
254
+
255
+ @property
256
+ def data_keys(self) -> list[str]:
257
+ keys = set(self._data)
258
+ keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
259
+ keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
260
+ return sorted(keys)
261
+
262
+ @property
263
+ def fps(self) -> float | None:
264
+ return self._fps
265
+
266
+ @property
267
+ def num_episodes(self) -> int:
268
+ return len(
269
+ np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
270
+ )
271
+
272
+ @property
273
+ def num_frames(self) -> int:
274
+ return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
275
+
276
+ def __len__(self):
277
+ return self.num_frames
278
+
279
+ def _item_to_tensors(self, item: dict) -> dict:
280
+ """Convert all values in an item dictionary to torch tensors.
281
+
282
+ Args:
283
+ item: Dictionary with numpy arrays, torch tensors, or scalars.
284
+
285
+ Returns:
286
+ Dictionary with all values converted to torch tensors.
287
+ """
288
+ item_ = {}
289
+ for k, v in item.items():
290
+ if isinstance(v, torch.Tensor):
291
+ item_[k] = v
292
+ elif isinstance(v, np.ndarray):
293
+ item_[k] = torch.from_numpy(v)
294
+ else:
295
+ item_[k] = torch.tensor(v)
296
+ return item_
297
+
298
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
299
+ if idx >= len(self) or idx < -len(self):
300
+ raise IndexError
301
+
302
+ item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
303
+
304
+ if self.delta_timestamps is None:
305
+ return self._item_to_tensors(item)
306
+
307
+ episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
308
+ current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
309
+ episode_data_indices = np.where(
310
+ np.bitwise_and(
311
+ self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
312
+ self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
313
+ )
314
+ )[0]
315
+ episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
316
+
317
+ for data_key in self.delta_timestamps:
318
+ # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
319
+ # Get timestamps used as query to retrieve data of previous/future frames.
320
+ query_ts = current_ts + self.delta_timestamps[data_key]
321
+
322
+ # Compute distances between each query timestamp and all timestamps of all the frames belonging to
323
+ # the episode.
324
+ dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
325
+ argmin_ = np.argmin(dist, axis=1)
326
+ min_ = dist[np.arange(dist.shape[0]), argmin_]
327
+
328
+ is_pad = min_ > self.tolerance_s
329
+
330
+ # Check violated query timestamps are all outside the episode range.
331
+ assert (
332
+ (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
333
+ ).all(), (
334
+ f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
335
+ ") inside the episode range."
336
+ )
337
+
338
+ # Load frames for this data key.
339
+ item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
340
+
341
+ item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
342
+
343
+ return self._item_to_tensors(item)
344
+
345
+ def get_data_by_key(self, key: str) -> torch.Tensor:
346
+ """Get all occupied data for a given key as a torch tensor.
347
+
348
+ Args:
349
+ key: Data key to retrieve.
350
+
351
+ Returns:
352
+ Tensor containing all non-padded data for the specified key.
353
+ """
354
+ return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
355
+
356
+
357
+ def compute_sampler_weights(
358
+ offline_dataset: LeRobotDataset,
359
+ offline_drop_n_last_frames: int = 0,
360
+ online_dataset: OnlineBuffer | None = None,
361
+ online_sampling_ratio: float | None = None,
362
+ online_drop_n_last_frames: int = 0,
363
+ ) -> torch.Tensor:
364
+ """Compute the sampling weights for the online training dataloader in train.py.
365
+
366
+ Args:
367
+ offline_dataset: The LeRobotDataset used for offline pre-training.
368
+ offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
369
+ online_dataset: The OnlineBuffer used in online training.
370
+ online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
371
+ online dataset is provided, this value must also be provided.
372
+ online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
373
+ dataset.
374
+ Returns:
375
+ Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
376
+
377
+ Notes to maintainers:
378
+ - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
379
+ - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
380
+ `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
381
+ is the ability to turn shuffling off.
382
+ - Options `drop_n_first_frames` and `episode_indices_to_use` can be added easily. They were not
383
+ included here to avoid adding complexity.
384
+ """
385
+ if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
386
+ raise ValueError("At least one of `offline_dataset` or `online_dataset` should contain data.")
387
+ if (online_dataset is None) ^ (online_sampling_ratio is None):
388
+ raise ValueError(
389
+ "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
390
+ )
391
+ offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
392
+
393
+ weights = []
394
+
395
+ if len(offline_dataset) > 0:
396
+ offline_data_mask_indices = []
397
+ for start_index, end_index in zip(
398
+ offline_dataset.episode_data_index["from"],
399
+ offline_dataset.episode_data_index["to"],
400
+ strict=True,
401
+ ):
402
+ offline_data_mask_indices.extend(
403
+ range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
404
+ )
405
+ offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
406
+ offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
407
+ weights.append(
408
+ torch.full(
409
+ size=(len(offline_dataset),),
410
+ fill_value=offline_sampling_ratio / offline_data_mask.sum(),
411
+ )
412
+ * offline_data_mask
413
+ )
414
+
415
+ if online_dataset is not None and len(online_dataset) > 0:
416
+ online_data_mask_indices = []
417
+ episode_indices = online_dataset.get_data_by_key("episode_index")
418
+ for episode_idx in torch.unique(episode_indices):
419
+ where_episode = torch.where(episode_indices == episode_idx)
420
+ start_index = where_episode[0][0]
421
+ end_index = where_episode[0][-1] + 1
422
+ online_data_mask_indices.extend(
423
+ range(start_index.item(), end_index.item() - online_drop_n_last_frames)
424
+ )
425
+ online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
426
+ online_data_mask[torch.tensor(online_data_mask_indices)] = True
427
+ weights.append(
428
+ torch.full(
429
+ size=(len(online_dataset),),
430
+ fill_value=online_sampling_ratio / online_data_mask.sum(),
431
+ )
432
+ * online_data_mask
433
+ )
434
+
435
+ weights = torch.cat(weights)
436
+
437
+ if weights.sum() == 0:
438
+ weights += 1 / len(weights)
439
+ else:
440
+ weights /= weights.sum()
441
+
442
+ return weights
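
The docstring above suggests pairing these weights with torch.utils.data.WeightedRandomSampler. A minimal sketch of that pairing, assuming an offline LeRobotDataset and an OnlineBuffer are already constructed (the variable names and hyperparameter values below are placeholders, not part of the package):

    import torch
    from opentau.datasets.online_buffer import compute_sampler_weights

    # `offline_dataset` and `online_buffer` are assumed to exist already;
    # the drop counts and sampling ratio are illustrative only.
    weights = compute_sampler_weights(
        offline_dataset,
        offline_drop_n_last_frames=7,
        online_dataset=online_buffer,
        online_sampling_ratio=0.5,
        online_drop_n_last_frames=7,
    )
    sampler = torch.utils.data.WeightedRandomSampler(
        weights, num_samples=len(weights), replacement=True
    )
    # Weights are ordered as [offline_dataset; online_dataset], so concatenate
    # the datasets in the same order before handing them to the DataLoader.
    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([offline_dataset, online_buffer]),
        batch_size=8,
        sampler=sampler,
    )
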
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import inspect
18
+ from concurrent.futures import ThreadPoolExecutor
19
+ from pathlib import Path
20
+ from typing import Dict
21
+
22
+ import datasets
23
+ import numpy
24
+ import PIL
25
+ import torch
26
+
27
+ from opentau.datasets.video_utils import encode_video_frames
28
+
29
+
30
+ def concatenate_episodes(ep_dicts):
31
+ data_dict = {}
32
+
33
+ keys = ep_dicts[0].keys()
34
+ for key in keys:
35
+ if torch.is_tensor(ep_dicts[0][key][0]):
36
+ data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
37
+ else:
38
+ if key not in data_dict:
39
+ data_dict[key] = []
40
+ for ep_dict in ep_dicts:
41
+ for x in ep_dict[key]:
42
+ data_dict[key].append(x)
43
+
44
+ total_frames = data_dict["frame_index"].shape[0]
45
+ data_dict["index"] = torch.arange(0, total_frames, 1)
46
+ return data_dict
47
+
48
+
49
+ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
50
+ out_dir = Path(out_dir)
51
+ out_dir.mkdir(parents=True, exist_ok=True)
52
+
53
+ def save_image(img_array, i, out_dir):
54
+ img = PIL.Image.fromarray(img_array)
55
+ img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
56
+
57
+ num_images = len(imgs_array)
58
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
59
+ [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
60
+
61
+
62
+ def get_default_encoding() -> dict:
63
+ """Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
64
+ signature = inspect.signature(encode_video_frames)
65
+ return {
66
+ k: v.default
67
+ for k, v in signature.parameters.items()
68
+ if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
69
+ }
70
+
71
+
72
+ def check_repo_id(repo_id: str) -> None:
73
+ if len(repo_id.split("/")) != 2:
74
+ raise ValueError(
75
+ f"""`repo_id` is expected to contain a community or user id, a '/', and the name of the dataset
76
+ (e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
77
+ )
78
+
79
+
80
+ # TODO(aliberts): remove
81
+ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
82
+ """
83
+ Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
84
+
85
+ Parameters:
86
+ - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
87
+
88
+ Returns:
89
+ - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
90
+ - "from": A tensor containing the starting index of each episode.
91
+ - "to": A tensor containing the ending index of each episode.
92
+ """
93
+ episode_data_index = {"from": [], "to": []}
94
+
95
+ current_episode = None
96
+ """
97
+ The episode_index is a list of integers, each representing the episode index of the corresponding example.
98
+ For instance, the following is a valid episode_index:
99
+ [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
100
+
101
+ Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
102
+ ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
103
+ {
104
+ "from": [0, 3, 7],
105
+ "to": [3, 7, 12]
106
+ }
107
+ """
108
+ if len(hf_dataset) == 0:
109
+ episode_data_index = {
110
+ "from": torch.tensor([]),
111
+ "to": torch.tensor([]),
112
+ }
113
+ return episode_data_index
114
+ for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
115
+ if episode_idx != current_episode:
116
+ # We encountered a new episode, so we append its starting location to the "from" list
117
+ episode_data_index["from"].append(idx)
118
+ # If this is not the first episode, we append the ending location of the previous episode to the "to" list
119
+ if current_episode is not None:
120
+ episode_data_index["to"].append(idx)
121
+ # Let's keep track of the current episode index
122
+ current_episode = episode_idx
123
+ else:
124
+ # We are still in the same episode, so there is nothing for us to do here
125
+ pass
126
+ # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
127
+ episode_data_index["to"].append(idx + 1)
128
+
129
+ for k in ["from", "to"]:
130
+ episode_data_index[k] = torch.tensor(episode_data_index[k])
131
+
132
+ return episode_data_index
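
As a concrete check of the "from"/"to" convention documented above, here is a small, self-contained sketch (the toy dataset is illustrative only):

    import datasets
    import torch
    from opentau.datasets.push_dataset_to_hub.utils import calculate_episode_data_index

    # Same episode_index pattern as in the docstring: episodes of 3, 4 and 5 frames.
    toy = datasets.Dataset.from_dict(
        {"episode_index": [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]}
    )
    index = calculate_episode_data_index(toy)
    assert torch.equal(index["from"], torch.tensor([0, 3, 7]))  # inclusive starts
    assert torch.equal(index["to"], torch.tensor([3, 7, 12]))   # exclusive ends
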
@@ -0,0 +1,99 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Episode-aware sampler for PyTorch DataLoader.
18
+
19
+ This module provides a sampler that respects episode boundaries in robot
20
+ learning datasets. It allows filtering specific episodes, dropping frames
21
+ from episode boundaries, and optional shuffling while maintaining episode
22
+ structure awareness.
23
+
24
+ The sampler is designed for use with PyTorch DataLoader to ensure proper
25
+ sampling behavior when working with sequential episode data, where episode
26
+ boundaries are important for maintaining temporal coherence or avoiding
27
+ invalid transitions.
28
+
29
+ Key Features:
30
+ - Episode filtering: Select specific episodes to include in sampling.
31
+ - Boundary frame dropping: Optionally drop frames from the start or end
32
+ of episodes (useful for avoiding invalid transitions or edge cases).
33
+ - Optional shuffling: Shuffle indices while maintaining episode awareness.
34
+ - PyTorch compatible: Implements the standard Sampler interface for use
35
+ with DataLoader.
36
+
37
+ Classes:
38
+ EpisodeAwareSampler: PyTorch-style sampler that respects episode
39
+ boundaries, supports episode filtering, frame dropping, and shuffling.
40
+
41
+ Example:
42
+ Create a sampler for specific episodes:
43
+ >>> episode_data_index = {"from": torch.tensor([0, 100, 200]), "to": torch.tensor([100, 200, 300])}
44
+ >>> sampler = EpisodeAwareSampler(
45
+ ... episode_data_index,
46
+ ... episode_indices_to_use=[0, 2], # Use episodes 0 and 2
47
+ ... drop_n_first_frames=5,
48
+ ... drop_n_last_frames=5,
49
+ ... shuffle=True
50
+ ... )
51
+ >>> dataloader = DataLoader(dataset, sampler=sampler)
52
+ """
53
+
54
+ from typing import Iterator, Union
55
+
56
+ import torch
57
+
58
+
59
+ class EpisodeAwareSampler:
60
+ def __init__(
61
+ self,
62
+ episode_data_index: dict,
63
+ episode_indices_to_use: Union[list, None] = None,
64
+ drop_n_first_frames: int = 0,
65
+ drop_n_last_frames: int = 0,
66
+ shuffle: bool = False,
67
+ ):
68
+ """Sampler that optionally incorporates episode boundary information.
69
+
70
+ Args:
71
+ episode_data_index: Dictionary with keys 'from' and 'to' containing the start (inclusive) and end (exclusive) indices of each episode.
72
+ episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
73
+ Assumes that episodes are indexed from 0 to N-1.
74
+ drop_n_first_frames: Number of frames to drop from the start of each episode.
75
+ drop_n_last_frames: Number of frames to drop from the end of each episode.
76
+ shuffle: Whether to shuffle the indices.
77
+ """
78
+ indices = []
79
+ for episode_idx, (start_index, end_index) in enumerate(
80
+ zip(episode_data_index["from"], episode_data_index["to"], strict=True)
81
+ ):
82
+ if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
83
+ indices.extend(
84
+ range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
85
+ )
86
+
87
+ self.indices = indices
88
+ self.shuffle = shuffle
89
+
90
+ def __iter__(self) -> Iterator[int]:
91
+ if self.shuffle:
92
+ for i in torch.randperm(len(self.indices)):
93
+ yield self.indices[i]
94
+ else:
95
+ for i in self.indices:
96
+ yield i
97
+
98
+ def __len__(self) -> int:
99
+ return len(self.indices)
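
A quick sanity check of the boundary-frame dropping, using the same inclusive-"from"/exclusive-"to" convention as calculate_episode_data_index (toy numbers, illustrative only):

    import torch
    from opentau.datasets.sampler import EpisodeAwareSampler

    episode_data_index = {
        "from": torch.tensor([0, 3, 7]),
        "to": torch.tensor([3, 7, 12]),
    }
    sampler = EpisodeAwareSampler(
        episode_data_index,
        episode_indices_to_use=[0, 2],  # skip episode 1
        drop_n_first_frames=1,
        drop_n_last_frames=1,
    )
    # Episode 0 keeps frame 1; episode 2 keeps frames 8, 9 and 10.
    assert list(sampler) == [1, 8, 9, 10]
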