opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/datasets/online_buffer.py

@@ -0,0 +1,442 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An online buffer for the online training loop in train.py

Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
supports in-place slicing and mutation which is very handy for a dynamic buffer.
"""

import os
from pathlib import Path
from typing import Any

import numpy as np
import torch

from opentau.datasets.lerobot_dataset import LeRobotDataset


def _make_memmap_safe(**kwargs) -> np.memmap:
    """Create a numpy memmap with checks on available disk space.

    Validates that sufficient disk space is available before creating the memmap
    file in write mode.

    Args:
        **kwargs: Keyword arguments for np.memmap including:
            - filename: Path to the memmap file.
            - dtype: numpy dtype (must be np.dtype).
            - mode: File mode ('r+', 'w+', etc.).
            - shape: Shape of the array.

    Returns:
        numpy memmap array.

    Raises:
        RuntimeError: If required disk space exceeds 80% of available space.

    Note:
        For information on dtypes, see:
        https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
    """
    if kwargs["mode"].startswith("w"):
        required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"])  # bytes
        stats = os.statvfs(Path(kwargs["filename"]).parent)
        available_space = stats.f_bavail * stats.f_frsize  # bytes
        if required_space >= available_space * 0.8:
            raise RuntimeError(
                f"You're about to take up {required_space} of {available_space} bytes available."
            )
    return np.memmap(**kwargs)


class OnlineBuffer(torch.utils.data.Dataset):
    """FIFO data buffer for the online training loop in train.py.

    Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
    loop in the same way that a LeRobotDataset would be used.

    The underlying data structure will have data inserted in a circular fashion. Always insert after the
    last index, and when you reach the end, wrap around to the start.

    The data is stored in a numpy memmap.
    """

    NEXT_INDEX_KEY = "_next_index"
    OCCUPANCY_MASK_KEY = "_occupancy_mask"
    INDEX_KEY = "index"
    FRAME_INDEX_KEY = "frame_index"
    EPISODE_INDEX_KEY = "episode_index"
    TIMESTAMP_KEY = "timestamp"
    IS_PAD_POSTFIX = "_is_pad"

    def __init__(
        self,
        write_dir: str | Path,
        data_spec: dict[str, Any] | None,
        buffer_capacity: int | None,
        fps: float | None = None,
        delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
    ):
        """
        The online buffer can be provided from scratch or you can load an existing online buffer by passing
        a `write_dir` associated with an existing buffer.

        Args:
            write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
                Note that if the files already exist, they are opened in read-write mode (used for training
                resumption.)
            data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
                "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
                but note that "index", "frame_index" and "episode_index" are already accounted for by this
                class, so you don't need to include them.
            buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
                system's available disk space when choosing this.
            fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
                delta_timestamps logic. You can pass None if you are not using delta_timestamps.
            delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
                converted to dict[str, np.ndarray] for optimization purposes.

        """
        self.set_delta_timestamps(delta_timestamps)
        self._fps = fps
        # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
        # the requested frames. It is only used when `delta_timestamps` is provided.
        # minus 1e-4 to account for possible numerical error
        self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
        self._buffer_capacity = buffer_capacity
        data_spec = self._make_data_spec(data_spec, buffer_capacity)
        Path(write_dir).mkdir(parents=True, exist_ok=True)
        self._data = {}
        for k, v in data_spec.items():
            self._data[k] = _make_memmap_safe(
                filename=Path(write_dir) / k,
                dtype=v["dtype"] if v is not None else None,
                mode="r+" if (Path(write_dir) / k).exists() else "w+",
                shape=tuple(v["shape"]) if v is not None else None,
            )

    @property
    def delta_timestamps(self) -> dict[str, np.ndarray] | None:
        return self._delta_timestamps

    def set_delta_timestamps(self, value: dict[str, list[float]] | None) -> None:
        """Set delta_timestamps converting the values to numpy arrays.

        The conversion is for an optimization in the __getitem__. The loop is much
        slower if the arrays need to be converted into numpy arrays on each access.

        Args:
            value: Dictionary mapping feature names to lists of delta timestamps,
                or None to disable delta timestamps.
        """
        if value is not None:
            self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
        else:
            self._delta_timestamps = None

    def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
        """Create the complete data specification for numpy memmap files.

        Adds internal keys (_next_index, _occupancy_mask) and standard keys
        (index, frame_index, episode_index, timestamp) to the user-provided spec.

        Args:
            data_spec: User-provided data specification dictionary.
            buffer_capacity: Maximum number of frames the buffer can hold.

        Returns:
            Complete data specification including internal and standard keys.

        Raises:
            ValueError: If data_spec contains keys starting with '_' or contains
                any of the preset keys (index, frame_index, episode_index, timestamp).
        """
        if any(k.startswith("_") for k in data_spec):
            raise ValueError(
                "data_spec keys should not start with '_'. This prefix is reserved for internal logic."
            )
        preset_keys = {
            OnlineBuffer.INDEX_KEY,
            OnlineBuffer.FRAME_INDEX_KEY,
            OnlineBuffer.EPISODE_INDEX_KEY,
            OnlineBuffer.TIMESTAMP_KEY,
        }
        if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
            raise ValueError(
                f"data_spec should not contain any of {preset_keys} as these are handled internally. "
                f"The provided data_spec has {intersection}."
            )
        complete_data_spec = {
            # _next_index will be a pointer to the next index that we should start filling from when we add
            # more data.
            OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
            # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
            # with real data rather than the dummy initialization.
            OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
            OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
        }
        for k, v in data_spec.items():
            complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
        return complete_data_spec

    def add_data(self, data: dict[str, np.ndarray]) -> None:
        """Add new data to the buffer, potentially overwriting old data in a circular fashion.

        The new data should contain all frames (in order) of any number of episodes.
        Indices should start from 0. This method shifts incoming data indices and
        episode indices to continue from the last frame in the buffer.

        Args:
            data: Dictionary mapping data keys to numpy arrays. All arrays must
                have the same length. Must include all keys from data_keys.

        Raises:
            ValueError: If data is missing required keys or arrays have different lengths.

        Note:
            This modifies the input data in place by shifting indices.
            See `rollout` and `eval_policy` functions in `eval.py` for usage examples.
        """
        if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
            raise ValueError(f"Missing data keys: {missing_keys}")
        new_data_length = len(data[self.data_keys[0]])
        if not all(len(data[k]) == new_data_length for k in self.data_keys):
            raise ValueError("All data items should have the same length")

        next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]

        # Sanity check to make sure that the new data indices start from 0.
        assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
        assert data[OnlineBuffer.INDEX_KEY][0].item() == 0

        # Shift the incoming indices if necessary.
        if self.num_frames > 0:
            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
            last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
            data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
            data[OnlineBuffer.INDEX_KEY] += last_data_index + 1

        # Insert the new data starting from next_index. It may be necessary to wrap around to the start.
        n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
        for k in self.data_keys:
            if n_surplus == 0:
                slc = slice(next_index, next_index + new_data_length)
                self._data[k][slc] = data[k]
                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
            else:
                self._data[k][next_index:] = data[k][:-n_surplus]
                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
                self._data[k][:n_surplus] = data[k][-n_surplus:]
        if n_surplus == 0:
            self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
        else:
            self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus

    @property
    def data_keys(self) -> list[str]:
        keys = set(self._data)
        keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
        keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
        return sorted(keys)

    @property
    def fps(self) -> float | None:
        return self._fps

    @property
    def num_episodes(self) -> int:
        return len(
            np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
        )

    @property
    def num_frames(self) -> int:
        return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])

    def __len__(self):
        return self.num_frames

    def _item_to_tensors(self, item: dict) -> dict:
        """Convert all values in an item dictionary to torch tensors.

        Args:
            item: Dictionary with numpy arrays, torch tensors, or scalars.

        Returns:
            Dictionary with all values converted to torch tensors.
        """
        item_ = {}
        for k, v in item.items():
            if isinstance(v, torch.Tensor):
                item_[k] = v
            elif isinstance(v, np.ndarray):
                item_[k] = torch.from_numpy(v)
            else:
                item_[k] = torch.tensor(v)
        return item_

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if idx >= len(self) or idx < -len(self):
            raise IndexError

        item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}

        if self.delta_timestamps is None:
            return self._item_to_tensors(item)

        episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
        current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
        episode_data_indices = np.where(
            np.bitwise_and(
                self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
            )
        )[0]
        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]

        for data_key in self.delta_timestamps:
            # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
            # Get timestamps used as query to retrieve data of previous/future frames.
            query_ts = current_ts + self.delta_timestamps[data_key]

            # Compute distances between each query timestamp and all timestamps of all the frames belonging to
            # the episode.
            dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
            argmin_ = np.argmin(dist, axis=1)
            min_ = dist[np.arange(dist.shape[0]), argmin_]

            is_pad = min_ > self.tolerance_s

            # Check violated query timestamps are all outside the episode range.
            assert (
                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
            ).all(), (
                f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
                ") inside the episode range."
            )

            # Load frames for this data key.
            item[data_key] = self._data[data_key][episode_data_indices[argmin_]]

            item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad

        return self._item_to_tensors(item)

    def get_data_by_key(self, key: str) -> torch.Tensor:
        """Get all occupied data for a given key as a torch tensor.

        Args:
            key: Data key to retrieve.

        Returns:
            Tensor containing all non-padded data for the specified key.
        """
        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])


def compute_sampler_weights(
    offline_dataset: LeRobotDataset,
    offline_drop_n_last_frames: int = 0,
    online_dataset: OnlineBuffer | None = None,
    online_sampling_ratio: float | None = None,
    online_drop_n_last_frames: int = 0,
) -> torch.Tensor:
    """Compute the sampling weights for the online training dataloader in train.py.

    Args:
        offline_dataset: The LeRobotDataset used for offline pre-training.
        offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
        online_dataset: The OnlineBuffer used in online training.
        online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
            online dataset is provided, this value must also be provided.
        online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
            dataset.
    Returns:
        Tensor of weights for [offline_dataset; online_dataset], normalized to 1.

    Notes to maintainers:
        - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
        - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
          `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
          is the ability to turn shuffling off.
        - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
          included here to avoid adding complexity.
    """
    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
        raise ValueError("At least one of `offline_dataset` or `online_dataset` should contain data.")
    if (online_dataset is None) ^ (online_sampling_ratio is None):
        raise ValueError(
            "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
        )
    offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio

    weights = []

    if len(offline_dataset) > 0:
        offline_data_mask_indices = []
        for start_index, end_index in zip(
            offline_dataset.episode_data_index["from"],
            offline_dataset.episode_data_index["to"],
            strict=True,
        ):
            offline_data_mask_indices.extend(
                range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
            )
        offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
        offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
        weights.append(
            torch.full(
                size=(len(offline_dataset),),
                fill_value=offline_sampling_ratio / offline_data_mask.sum(),
            )
            * offline_data_mask
        )

    if online_dataset is not None and len(online_dataset) > 0:
        online_data_mask_indices = []
        episode_indices = online_dataset.get_data_by_key("episode_index")
        for episode_idx in torch.unique(episode_indices):
            where_episode = torch.where(episode_indices == episode_idx)
            start_index = where_episode[0][0]
            end_index = where_episode[0][-1] + 1
            online_data_mask_indices.extend(
                range(start_index.item(), end_index.item() - online_drop_n_last_frames)
            )
        online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
        online_data_mask[torch.tensor(online_data_mask_indices)] = True
        weights.append(
            torch.full(
                size=(len(online_dataset),),
                fill_value=online_sampling_ratio / online_data_mask.sum(),
            )
            * online_data_mask
        )

    weights = torch.cat(weights)

    if weights.sum() == 0:
        weights += 1 / len(weights)
    else:
        weights /= weights.sum()

    return weights
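
For readers skimming the diff, here is a minimal sketch of how the buffer above might be exercised on its own. The write directory and the "observation.state" feature are made up for illustration; the index, frame_index, episode_index and timestamp arrays are required by add_data (the buffer creates those memmaps itself, so they are not listed in data_spec).

import numpy as np
import torch

from opentau.datasets.online_buffer import OnlineBuffer

# Hypothetical spec: one 4-dim state vector per frame. Internal keys are added automatically.
buffer = OnlineBuffer(
    write_dir="/tmp/online_buffer_demo",
    data_spec={"observation.state": {"shape": (4,), "dtype": np.dtype("float32")}},
    buffer_capacity=1_000,
    fps=10,
)

# One fake 5-frame episode. Per add_data's contract, "index" and "episode_index" start at 0;
# the buffer shifts them to follow whatever is already stored.
n = 5
buffer.add_data(
    {
        "observation.state": np.zeros((n, 4), dtype=np.float32),
        "index": np.arange(n),
        "frame_index": np.arange(n),
        "episode_index": np.zeros(n, dtype=np.int64),
        "timestamp": np.arange(n) / 10,
    }
)

# The buffer follows the LeRobotDataset protocol closely enough to feed a DataLoader directly.
loader = torch.utils.data.DataLoader(buffer, batch_size=2)
batch = next(iter(loader))
print(batch["observation.state"].shape)  # torch.Size([2, 4])
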
opentau/datasets/push_dataset_to_hub/utils.py

@@ -0,0 +1,132 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict

import datasets
import numpy
import PIL
import torch

from opentau.datasets.video_utils import encode_video_frames


def concatenate_episodes(ep_dicts):
    data_dict = {}

    keys = ep_dicts[0].keys()
    for key in keys:
        if torch.is_tensor(ep_dicts[0][key][0]):
            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
        else:
            if key not in data_dict:
                data_dict[key] = []
            for ep_dict in ep_dicts:
                for x in ep_dict[key]:
                    data_dict[key].append(x)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    def save_image(img_array, i, out_dir):
        img = PIL.Image.fromarray(img_array)
        img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)

    num_images = len(imgs_array)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]


def get_default_encoding() -> dict:
    """Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
    signature = inspect.signature(encode_video_frames)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
    }


def check_repo_id(repo_id: str) -> None:
    if len(repo_id.split("/")) != 2:
        raise ValueError(
            f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
            (e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
        )


# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
    """
    Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.

    Parameters:
    - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.

    Returns:
    - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
        - "from": A tensor containing the starting index of each episode.
        - "to": A tensor containing the ending index of each episode.
    """
    episode_data_index = {"from": [], "to": []}

    current_episode = None
    """
    The episode_index is a list of integers, each representing the episode index of the corresponding example.
    For instance, the following is a valid episode_index:
      [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]

    Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
    ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
        {
            "from": [0, 3, 7],
            "to": [3, 7, 12]
        }
    """
    if len(hf_dataset) == 0:
        episode_data_index = {
            "from": torch.tensor([]),
            "to": torch.tensor([]),
        }
        return episode_data_index
    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
        if episode_idx != current_episode:
            # We encountered a new episode, so we append its starting location to the "from" list
            episode_data_index["from"].append(idx)
            # If this is not the first episode, we append the ending location of the previous episode to the "to" list
            if current_episode is not None:
                episode_data_index["to"].append(idx)
            # Let's keep track of the current episode index
            current_episode = episode_idx
        else:
            # We are still in the same episode, so there is nothing for us to do here
            pass
    # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
    episode_data_index["to"].append(idx + 1)

    for k in ["from", "to"]:
        episode_data_index[k] = torch.tensor(episode_data_index[k])

    return episode_data_index
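
The `calculate_episode_data_index` helper is the glue between a flat HuggingFace dataset and the episode-aware tooling in this package. A quick sanity check on the docstring's own example, assuming only the import path implied by the file listing above:

import datasets

from opentau.datasets.push_dataset_to_hub.utils import calculate_episode_data_index

# Same episode_index as in the docstring: three episodes of lengths 3, 4 and 5.
hf_dataset = datasets.Dataset.from_dict({"episode_index": [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]})
index = calculate_episode_data_index(hf_dataset)
print(index["from"])  # tensor([0, 3, 7])
print(index["to"])    # tensor([ 3,  7, 12])  -- "to" is an exclusive end index
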
opentau/datasets/sampler.py

@@ -0,0 +1,99 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Episode-aware sampler for PyTorch DataLoader.

This module provides a sampler that respects episode boundaries in robot
learning datasets. It allows filtering specific episodes, dropping frames
from episode boundaries, and optional shuffling while maintaining episode
structure awareness.

The sampler is designed for use with PyTorch DataLoader to ensure proper
sampling behavior when working with sequential episode data, where episode
boundaries are important for maintaining temporal coherence or avoiding
invalid transitions.

Key Features:
    - Episode filtering: Select specific episodes to include in sampling.
    - Boundary frame dropping: Optionally drop frames from the start or end
      of episodes (useful for avoiding invalid transitions or edge cases).
    - Optional shuffling: Shuffle indices while maintaining episode awareness.
    - PyTorch compatible: Implements the standard Sampler interface for use
      with DataLoader.

Classes:
    EpisodeAwareSampler: PyTorch-style sampler that respects episode
        boundaries, supports episode filtering, frame dropping, and shuffling.

Example:
    Create a sampler for specific episodes:
    >>> episode_data_index = {"from": torch.tensor([0, 100, 200]), "to": torch.tensor([99, 199, 299])}
    >>> sampler = EpisodeAwareSampler(
    ...     episode_data_index,
    ...     episode_indices_to_use=[0, 2],  # Use episodes 0 and 2
    ...     drop_n_first_frames=5,
    ...     drop_n_last_frames=5,
    ...     shuffle=True
    ... )
    >>> dataloader = DataLoader(dataset, sampler=sampler)
"""

from typing import Iterator, Union

import torch


class EpisodeAwareSampler:
    def __init__(
        self,
        episode_data_index: dict,
        episode_indices_to_use: Union[list, None] = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
        shuffle: bool = False,
    ):
        """Sampler that optionally incorporates episode boundary information.

        Args:
            episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
            episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
                Assumes that episodes are indexed from 0 to N-1.
            drop_n_first_frames: Number of frames to drop from the start of each episode.
            drop_n_last_frames: Number of frames to drop from the end of each episode.
            shuffle: Whether to shuffle the indices.
        """
        indices = []
        for episode_idx, (start_index, end_index) in enumerate(
            zip(episode_data_index["from"], episode_data_index["to"], strict=True)
        ):
            if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
                indices.extend(
                    range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
                )

        self.indices = indices
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        if self.shuffle:
            for i in torch.randperm(len(self.indices)):
                yield self.indices[i]
        else:
            for i in self.indices:
                yield i

    def __len__(self) -> int:
        return len(self.indices)
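
A small, self-contained check of the boundary-drop behaviour described in the docstring. The episode boundaries here are invented; in practice they come from a dataset's episode_data_index, for example as produced by `calculate_episode_data_index` in push_dataset_to_hub/utils.py above.

import torch

from opentau.datasets.sampler import EpisodeAwareSampler

# Two toy episodes covering frames 0-4 and 5-9 ("to" is an exclusive end index).
episode_data_index = {
    "from": torch.tensor([0, 5]),
    "to": torch.tensor([5, 10]),
}
sampler = EpisodeAwareSampler(
    episode_data_index,
    drop_n_first_frames=1,
    drop_n_last_frames=2,
    shuffle=False,
)
print(list(sampler))  # [1, 2, 6, 7] -- the first frame and last two frames of each episode are skipped
print(len(sampler))   # 4
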