opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/datasets/video_utils.py
ADDED
@@ -0,0 +1,597 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video encoding, decoding, and information extraction utilities.

This module provides functionality for working with video files in robot learning
datasets, including frame extraction at specific timestamps, video encoding from
image sequences, and metadata extraction. It supports multiple video backends
for flexible deployment across different platforms.

The module handles the complexity of video codecs, including inter-frame compression
where frames are stored as differences relative to key frames. This requires
loading preceding key frames when accessing specific timestamps, which the module
handles automatically.

Key Features:
    - Multiple backends: Supports torchcodec (when available), pyav, and
      video_reader backends with automatic fallback.
    - Timestamp-based frame extraction: Extracts frames at specific timestamps
      with tolerance checking to ensure synchronization.
    - Video encoding: Encodes image sequences to video files using ffmpeg with
      configurable codecs and quality settings.
    - Metadata extraction: Extracts video and audio stream information using
      ffprobe.
    - HuggingFace integration: Provides VideoFrame feature type for HuggingFace
      datasets.

Classes:

    VideoFrame
        PyArrow-based feature type for HuggingFace datasets containing video
        frames with path and timestamp information.

Functions:

    Video decoding:
        decode_video_frames
            Main interface for decoding frames at timestamps with automatic backend selection.
        decode_video_frames_torchcodec
            Decode frames using torchcodec backend.
        decode_video_frames_torchvision
            Decode frames using torchvision backends (pyav or video_reader).

    Video encoding:
        encode_video_frames
            Encode a sequence of PNG images into a video file using ffmpeg.

    Video information:
        get_video_info
            Extract video stream metadata (fps, dimensions, codec).
        get_audio_info
            Extract audio stream metadata (channels, codec, bitrate).
        get_video_pixel_channels
            Determine pixel channels from pixel format.
        get_image_pixel_channels
            Determine pixel channels from PIL Image mode.

    Backend management:
        get_safe_default_codec
            Get default codec backend with fallback logic.

Example:
    Decode frames at specific timestamps:
    >>> frames = decode_video_frames(
    ...     video_path="videos/episode_0.mp4",
    ...     timestamps=[0.1, 0.2, 0.3],
    ...     tolerance_s=1e-4,
    ...     backend="torchcodec"
    ... )

    Encode images to video:
    >>> encode_video_frames(
    ...     imgs_dir="images/episode_0",
    ...     video_path="videos/episode_0.mp4",
    ...     fps=30,
    ...     vcodec="libsvtav1"
    ... )

    Get video information:
    >>> info = get_video_info("videos/episode_0.mp4")
    >>> print(f"FPS: {info['video.fps']}, Resolution: {info['video.width']}x{info['video.height']}")
"""

import importlib
import json
import logging
import subprocess
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar

import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image


def get_safe_default_codec() -> str:
    """Get the default video codec backend, falling back to pyav if torchcodec is unavailable.

    Returns:
        Backend name: "torchcodec" if available, otherwise "pyav".
    """
    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"
    else:
        logging.warning(
            "'torchcodec' is not available on your platform, falling back to 'pyav' as the default decoder"
        )
        return "pyav"


def decode_video_frames(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str | None = None,
) -> torch.Tensor:
    """
    Decodes video frames using the specified backend.

    Args:
        video_path (Path): Path to the video file.
        timestamps (list[float]): List of timestamps to extract frames.
        tolerance_s (float): Allowed deviation in seconds for frame retrieval.
        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec"
            when available on the platform; otherwise, defaults to "pyav".

    Returns:
        torch.Tensor: Decoded frames.

    Currently supports torchcodec on CPU and pyav.
    """
    if backend is None:
        backend = get_safe_default_codec()
    if backend == "torchcodec":
        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
    elif backend in ["pyav", "video_reader"]:
        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
    else:
        raise ValueError(f"Unsupported video backend: {backend}")
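
# Usage sketch (illustrative, not part of the original file; the path and fps
# are assumptions). Choosing `tolerance_s` just under one frame period means a
# requested timestamp can only match its single nearest decoded frame:
#
#     fps = 30
#     frames = decode_video_frames(
#         video_path="videos/episode_0.mp4",
#         timestamps=[0.0, 1 / fps, 2 / fps],
#         tolerance_s=1 / fps - 1e-4,
#     )
#     # frames: float32 tensor in [0, 1], shape (3, C, H, W)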


def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video.

    The backend can be either "pyav" (default) or "video_reader".
    "video_reader" requires installing torchvision from source, see:
    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
    (note that you need to compile against ffmpeg<4.3)

    While both use the CPU, "video_reader" is supposedly faster than "pyav" but requires additional setup.
    For more info on video decoding, see `benchmark/video/README.md`

    See torchvision doc for more info on these two backends:
    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    video_path = str(video_path)

    # set backend
    keyframes_only = False
    torchvision.set_video_backend(backend)
    if backend == "pyav":
        keyframes_only = True  # pyav doesn't support accurate seek

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(video_path, "video")

    # set the first and last requested timestamps
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    # access closest key frame of the first requested frame
    # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
    # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
    reader.seek(first_ts, keyframes_only=keyframes_only)

    # load all frames until last requested frame
    loaded_frames = []
    loaded_ts = []
    for frame in reader:
        current_ts = frame["pts"]
        if log_loaded_timestamps:
            logging.info(f"frame loaded at timestamp={current_ts:.4f}")
        loaded_frames.append(frame["data"])
        loaded_ts.append(current_ts)
        if current_ts >= last_ts:
            break

    if backend == "pyav":
        reader.container.close()

    reader = None

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        " It means that the closest frame that can be loaded from the video is too far away in time."
        " This might be due to synchronization issues with timestamps during data collection."
        " To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
        f"\nbackend: {backend}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
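
# The timestamp-matching step above, in miniature (editor's illustration):
#
#     query_ts = torch.tensor([0.10, 0.21])
#     loaded_ts = torch.tensor([0.00, 0.10, 0.20, 0.30])
#     dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
#     min_, argmin_ = dist.min(1)
#     # min_ ≈ tensor([0.00, 0.01]), argmin_ = tensor([1, 2]): each query
#     # timestamp is paired with its nearest decoded frame, and min_ is
#     # checked against `tolerance_s`.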


def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """

    if importlib.util.find_spec("torchcodec"):
        from torchcodec.decoders import VideoDecoder
    else:
        raise ImportError("torchcodec is required but not available.")

    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device, seek_mode="exact")
    loaded_frames = []
    loaded_ts = []
    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        " It means that the closest frame that can be loaded from the video is too far away in time."
        " This might be due to synchronization issues with timestamps during data collection."
        " To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
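
# Index conversion above, in miniature (editor's illustration): with
# average_fps = 30.0, timestamps [0.0, 0.1, 0.5] map to frame indices
# [0, 3, 15] via round(ts * average_fps); the decoder's reported pts_seconds
# are then checked against `tolerance_s`, as in the torchvision path.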


def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: str | None = "error",
    overwrite: bool = False,
) -> None:
    """Encode a sequence of images into a video file using ffmpeg.

    Args:
        imgs_dir: Directory containing sequentially numbered PNG frames
            (frame_000000.png, frame_000001.png, etc.).
        video_path: Output path for the encoded video file.
        fps: Frames per second for the output video.
        vcodec: Video codec to use. Defaults to "libsvtav1".
        pix_fmt: Pixel format. Defaults to "yuv420p".
        g: GOP (Group of Pictures) size. Defaults to 2.
        crf: Constant Rate Factor for quality control. Defaults to 30.
        fast_decode: Fast decode parameter for libsvtav1. Defaults to 0.
        log_level: FFmpeg log level. Defaults to "error".
        overwrite: Whether to overwrite an existing video file. Defaults to False.

    Raises:
        OSError: If video encoding fails or the output file is not created.

    Note:
        More info on tuning ffmpeg arguments in `benchmark/video/README.md`
    """
    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)
    video_path.parent.mkdir(parents=True, exist_ok=True)

    ffmpeg_args = OrderedDict(
        [
            ("-f", "image2"),
            ("-r", str(fps)),
            ("-i", str(imgs_dir / "frame_%06d.png")),
            ("-vcodec", vcodec),
            ("-pix_fmt", pix_fmt),
        ]
    )

    if g is not None:
        ffmpeg_args["-g"] = str(g)

    if crf is not None:
        ffmpeg_args["-crf"] = str(crf)

    if fast_decode:
        key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
        ffmpeg_args[key] = value

    if log_level is not None:
        ffmpeg_args["-loglevel"] = str(log_level)

    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
    if overwrite:
        ffmpeg_args.append("-y")

    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)

    if not video_path.exists():
        raise OSError(
            f"Video encoding did not work. File not found: {video_path}. "
            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
        )
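
# With the defaults above, the constructed command looks like this (editor's
# illustration, assuming fps=30 and the paths from the module docstring):
#
#     ffmpeg -f image2 -r 30 -i images/episode_0/frame_%06d.png \
#         -vcodec libsvtav1 -pix_fmt yuv420p -g 2 -crf 30 -loglevel error \
#         videos/episode_0.mp4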


@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
    """
    Provides a type for a dataset containing video frames.

    Example:

    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        return self.pa_type


with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        "'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    # to make VideoFrame available in HuggingFace `datasets`
    register_feature(VideoFrame, "VideoFrame")


def get_audio_info(video_path: Path | str) -> dict:
    """Extract audio stream information from a video file using ffprobe.

    Args:
        video_path: Path to the video file.

    Returns:
        Dictionary containing audio information:
            - has_audio: Boolean indicating if an audio stream exists.
            - audio.channels: Number of audio channels (if available).
            - audio.codec: Audio codec name (if available).
            - audio.bit_rate: Bit rate in bits per second (if available).
            - audio.sample_rate: Sample rate in Hz (if available).
            - audio.bit_depth: Bit depth (if available).
            - audio.channel_layout: Channel layout (if available).

    Raises:
        RuntimeError: If the ffprobe command fails.
    """
    ffprobe_audio_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "a:0",
        "-show_entries",
        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    audio_stream_info = info["streams"][0] if info.get("streams") else None
    if audio_stream_info is None:
        return {"has_audio": False}

    # Return the information, with fields defaulting to None when missing from the stream
    return {
        "has_audio": True,
        "audio.channels": audio_stream_info.get("channels", None),
        "audio.codec": audio_stream_info.get("codec_name", None),
        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
        "audio.sample_rate": int(audio_stream_info["sample_rate"])
        if audio_stream_info.get("sample_rate")
        else None,
        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
    }
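
# A hypothetical return value for a file with a stereo AAC track (editor's
# illustration; actual values depend on the file):
#
#     {"has_audio": True, "audio.channels": 2, "audio.codec": "aac",
#      "audio.bit_rate": 128000, "audio.sample_rate": 44100,
#      "audio.bit_depth": None, "audio.channel_layout": "stereo"}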


def get_video_info(video_path: Path | str) -> dict:
    """Extract video stream information from a video file using ffprobe.

    Args:
        video_path: Path to the video file.

    Returns:
        Dictionary containing video and audio information:
            - video.fps: Frames per second.
            - video.height: Video height in pixels.
            - video.width: Video width in pixels.
            - video.channels: Number of pixel channels.
            - video.codec: Video codec name.
            - video.pix_fmt: Pixel format.
            - video.is_depth_map: Whether video is a depth map.
            - Plus all fields from get_audio_info().

    Raises:
        RuntimeError: If ffprobe command fails.
    """
    ffprobe_video_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-select_streams",
        "v:0",
        "-show_entries",
        "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
        "-of",
        "json",
        str(video_path),
    ]
    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {result.stderr}")

    info = json.loads(result.stdout)
    video_stream_info = info["streams"][0]

    # Calculate fps from r_frame_rate
    r_frame_rate = video_stream_info["r_frame_rate"]
    num, denom = map(int, r_frame_rate.split("/"))
    fps = num / denom

    pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])

    video_info = {
        "video.fps": fps,
        "video.height": video_stream_info["height"],
        "video.width": video_stream_info["width"],
        "video.channels": pixel_channels,
        "video.codec": video_stream_info["codec_name"],
        "video.pix_fmt": video_stream_info["pix_fmt"],
        "video.is_depth_map": False,
        **get_audio_info(video_path),
    }

    return video_info
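
# A hypothetical return value for a 30 fps, 640x480, AV1-encoded video with no
# audio track (editor's illustration; actual values depend on the file):
#
#     {"video.fps": 30.0, "video.height": 480, "video.width": 640,
#      "video.channels": 3, "video.codec": "av1", "video.pix_fmt": "yuv420p",
#      "video.is_depth_map": False, "has_audio": False}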


def get_video_pixel_channels(pix_fmt: str) -> int:
    """Determine the number of pixel channels from a pixel format string.

    Args:
        pix_fmt: Pixel format string (e.g., "yuv420p", "rgb24").

    Returns:
        Number of channels (1, 3, or 4).

    Raises:
        ValueError: If pixel format is unknown.
    """
    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
        return 1
    elif "rgba" in pix_fmt or "yuva" in pix_fmt:
        return 4
    elif "rgb" in pix_fmt or "yuv" in pix_fmt:
        return 3
    else:
        raise ValueError(f"Unknown pixel format: {pix_fmt}")
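
# Substring matching above, in miniature (editor's illustration):
#
#     get_video_pixel_channels("gray")      # -> 1
#     get_video_pixel_channels("rgba64le")  # -> 4
#     get_video_pixel_channels("yuv420p")   # -> 3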


def get_image_pixel_channels(image: Image.Image) -> int:
    """Determine the number of pixel channels from a PIL Image mode.

    Args:
        image: PIL Image object.

    Returns:
        Number of channels (1, 2, 3, or 4).

    Raises:
        ValueError: If image mode is unknown.
    """
    if image.mode == "L":
        return 1  # Grayscale
    elif image.mode == "LA":
        return 2  # Grayscale + Alpha
    elif image.mode == "RGB":
        return 3  # RGB
    elif image.mode == "RGBA":
        return 4  # RGBA
    else:
        raise ValueError(f"Unknown image mode: {image.mode}")
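
# Mode mapping above, in miniature (editor's illustration):
#
#     from PIL import Image
#     get_image_pixel_channels(Image.new("RGB", (1, 1)))  # -> 3
#     get_image_pixel_channels(Image.new("LA", (1, 1)))   # -> 2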
opentau/envs/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""This package includes environments for training and evaluating policies. Only LIBERO is supported for now."""

from .configs import EnvConfig  # noqa: F401