opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
4
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
"""Utilities for file I/O operations.
|
|
18
|
+
|
|
19
|
+
This module provides a function for saving a sequence of frames as a video, and a
function for loading a JSON file into an existing object template with strict
structure and type checking.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import json
|
|
24
|
+
import warnings
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from typing import TypeVar
|
|
27
|
+
|
|
28
|
+
import imageio
|
|
29
|
+
|
|
30
|
+
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
|
31
|
+
T = TypeVar("T", bound=JsonLike)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def write_video(video_path: str | Path, stacked_frames: list, fps: float) -> None:
    """Encode a sequence of frames into a video file with imageio.

    Args:
        video_path: Destination file for the video.
        stacked_frames: Frames to encode, in playback order.
        fps: Playback rate of the output video, in frames per second.
    """
    with warnings.catch_warnings():
        # Silence only the pkg_resources DeprecationWarning that may be emitted
        # while imageio writes the file; all other warnings still surface.
        warnings.filterwarnings(
            action="ignore",
            message="pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
    """Fill a template object with values loaded from a JSON file.

    The JSON document at ``fpath`` must mirror ``obj`` exactly. Dicts must have the
    same key sets, and lists and tuples must have the same lengths. Every leaf value
    must have exactly the same Python type as the corresponding leaf of ``obj``.
    JSON has no tuple type, so a tuple in ``obj`` is matched against a JSON array and
    rebuilt as a tuple. Dicts and lists inside ``obj`` are updated in place.

    Args:
        fpath: Path to the JSON file to load (read as UTF-8).
        obj: Template object with the desired structure and types.

    Returns:
        Object with the same structure as ``obj``, filled with values from the JSON file.

    Raises:
        TypeError: If structure or leaf types don't match between the JSON data and ``obj``.
        ValueError: If dictionary keys or list/tuple lengths don't match.
    """
    with open(fpath, encoding="utf-8") as f:
        data = json.load(f)

    def _merge(template, loaded):
        """Return ``template`` overwritten with ``loaded``, enforcing identical structure."""
        # Leaf values: the loaded value replaces the template, but only when the
        # types are exactly identical (e.g. an int template rejects a JSON float).
        if not isinstance(template, (dict, list, tuple)):
            if type(template) is not type(loaded):
                raise TypeError(f"Type mismatch: expected {type(template)}, got {type(loaded)}")
            return loaded

        if isinstance(template, dict):
            if not isinstance(loaded, dict):
                raise TypeError(f"Type mismatch: expected dict, got {type(loaded)}")
            if template.keys() != loaded.keys():
                raise ValueError(
                    f"Dictionary keys do not match.\nExpected: {template.keys()}, got: {loaded.keys()}"
                )
            for key in template:
                template[key] = _merge(template[key], loaded[key])
            return template

        if isinstance(template, list):
            if not isinstance(loaded, list):
                raise TypeError(f"Type mismatch: expected list, got {type(loaded)}")
            if len(template) != len(loaded):
                raise ValueError(f"List length mismatch: expected {len(template)}, got {len(loaded)}")
            for idx, (old_item, new_item) in enumerate(zip(template, loaded)):
                template[idx] = _merge(old_item, new_item)
            return template

        # Remaining case: a tuple template, which JSON stores as an array.
        if not isinstance(loaded, list):
            raise TypeError(f"Type mismatch: expected list (for tuple), got {type(loaded)}")
        if len(template) != len(loaded):
            raise ValueError(f"Tuple length mismatch: expected {len(template)}, got {len(loaded)}")
        # Tuples are immutable, so a fresh tuple is built from the converted elements.
        return tuple(_merge(old_item, new_item) for old_item, new_item in zip(template, loaded))

    return _merge(obj, data)
|
opentau/utils/libero.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Utilities for working with the LIBERO robotics environment.
|
|
16
|
+
|
|
17
|
+
This module provides functions for converting LIBERO observations to PyTorch tensors,
|
|
18
|
+
summarizing LIBERO evaluation results, and recording observations from LIBERO environments.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import logging
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
import imageio
|
|
25
|
+
import numpy as np
|
|
26
|
+
import torch
|
|
27
|
+
from einops import rearrange
|
|
28
|
+
from robosuite.utils.transform_utils import quat2axisangle
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def rotate_numpy_image(image: np.ndarray) -> np.ndarray:
    """Rotate an HWC image by 180 degrees and return it as a CHW float array in [0, 1].

    Args:
        image: Image array laid out as (height, width, channels) with values in [0, 255].

    Returns:
        Float64 array laid out as (channels, height, width), rotated by 180 degrees
        in the height/width plane and scaled to [0, 1].
    """
    scaled = image.astype(float) / 255.0
    # Two 90-degree turns over the first two axes (H, W) form a 180-degree rotation.
    turned = np.rot90(scaled, k=2)
    # Move channels first: (H, W, C) -> (C, H, W).
    return np.transpose(turned, (2, 0, 1))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _libero2np(obs: dict[str, np.ndarray], cfg) -> dict[str, str | np.ndarray]:
    """Build the numpy model input for one LIBERO observation.

    Args:
        obs: LIBERO observation dictionary. This function reads the keys
            ``robot0_eef_pos``, ``robot0_eef_quat``, ``robot0_gripper_qpos``,
            ``agentview_image`` and ``robot0_eye_in_hand_image``.
        cfg: Configuration object providing ``libero.task.language``, ``max_state_dim``,
            ``num_cams`` and ``action_chunk``.

    Returns:
        Dictionary with these entries:
        - ``camera0``: the agent-view image, after ``rotate_numpy_image``.
        - ``camera1``: the eye-in-hand image, after ``rotate_numpy_image``.
        - ``prompt``: the task language string.
        - ``state``: the state vector, zero-padded on the right to ``cfg.max_state_dim``.
        - ``img_is_pad``: an all-False bool array of length ``cfg.num_cams``.
        - ``action_is_pad``: an all-False bool array of length ``cfg.action_chunk``.

        ``np.pad`` raises ValueError if ``cfg.max_state_dim`` is smaller than the
        state length.
    """
    # State vector: end-effector position, then end-effector orientation converted
    # from a quaternion to axis-angle, then gripper joint positions.
    proprio = np.hstack(
        (
            obs["robot0_eef_pos"],
            quat2axisangle(obs["robot0_eef_quat"]),
            obs["robot0_gripper_qpos"],
        )
    )

    camera0 = rotate_numpy_image(obs["agentview_image"])
    camera1 = rotate_numpy_image(obs["robot0_eye_in_hand_image"])

    return dict(
        camera0=camera0,
        camera1=camera1,
        prompt=cfg.libero.task.language,
        state=np.pad(proprio, (0, cfg.max_state_dim - len(proprio))),
        img_is_pad=np.zeros(cfg.num_cams, dtype=bool),
        action_is_pad=np.zeros(cfg.action_chunk, dtype=bool),
    )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _np2torch(
|
|
76
|
+
np_input: dict[str, str | np.ndarray], device: str, dtype: torch.dtype
|
|
77
|
+
) -> dict[str, str | torch.Tensor]:
|
|
78
|
+
"""Convert numpy arrays in dictionary to PyTorch tensors.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
np_input: Dictionary containing numpy arrays and strings.
|
|
82
|
+
device: Target device for tensors (e.g., 'cuda', 'cpu').
|
|
83
|
+
dtype: Target dtype for floating point tensors.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
Dictionary with numpy arrays converted to PyTorch tensors on the
|
|
87
|
+
specified device. String values are preserved as-is.
|
|
88
|
+
|
|
89
|
+
Raises:
|
|
90
|
+
TypeError: If a value type is not supported (not str or np.ndarray).
|
|
91
|
+
"""
|
|
92
|
+
torch_input = {}
|
|
93
|
+
for k, v in np_input.items():
|
|
94
|
+
if isinstance(v, str):
|
|
95
|
+
torch_input[k] = v
|
|
96
|
+
elif isinstance(v, np.ndarray):
|
|
97
|
+
# .copy() ensures the array is contiguous for PyTorch to use it
|
|
98
|
+
tensor = torch.tensor(v.copy())
|
|
99
|
+
if tensor.dtype.is_floating_point:
|
|
100
|
+
tensor = tensor.to(dtype=dtype)
|
|
101
|
+
torch_input[k] = tensor.to(device)
|
|
102
|
+
else:
|
|
103
|
+
raise TypeError(f"Unsupported type {type(v)} for key {k}.")
|
|
104
|
+
return torch_input
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def libero2torch(
    obs: dict[str, np.ndarray], cfg, device: str, dtype: torch.dtype
) -> dict[str, str | torch.Tensor]:
    """Convert a raw LIBERO observation into model-ready torch inputs.

    This first builds the numpy representation with ``_libero2np``, then converts
    it with ``_np2torch``.

    Args:
        obs: LIBERO observation dictionary containing robot state and images.
        cfg: Configuration object with task language, state dimensions, etc.
        device: Device the tensors are moved to (e.g. 'cuda', 'cpu').
        dtype: dtype applied to floating-point tensors.

    Returns:
        Dictionary of camera images, state, prompt and padding flags. Arrays become
        tensors on ``device``; the prompt stays a string.
    """
    return _np2torch(_libero2np(obs, cfg), device, dtype)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def summarize_libero_results(results: list[int]) -> dict:
    """Summarize LIBERO evaluation results.

    Args:
        results: One integer per simulation, where:
            - A non-negative value means success; it is the number of steps taken
              (0 also counts as success).
            - -1 means failure.
            - -2 means crash.
            Any other negative value falls in none of the three categories. A warning
            is logged, and the three rates then sum to less than 1.

    Returns:
        If ``results`` is empty, ``{"message": "No results to summarize."}``.
        Otherwise, a dictionary with:
        - the total count;
        - success/failure/crash indices, counts and rates;
        - ``steps_taken``, a copy of ``results`` (later changes to the caller's list
          do not affect the summary);
        - ``avg_steps_taken_until_success``, the mean of the successful values, or
          None if there were none.
    """
    # Materialize once so any iterable works and the summary never aliases the caller's list.
    steps = list(results)
    if not steps:
        return {"message": "No results to summarize."}

    total = len(steps)
    success_indices = [i for i, r in enumerate(steps) if r >= 0]
    failure_indices = [i for i, r in enumerate(steps) if r == -1]
    crashed_indices = [i for i, r in enumerate(steps) if r == -2]

    uncategorized = total - len(success_indices) - len(failure_indices) - len(crashed_indices)
    if uncategorized:
        logging.warning(
            "%d LIBERO result(s) have unrecognized codes and are excluded from all rates.",
            uncategorized,
        )

    successful_steps = [steps[i] for i in success_indices]
    avg_steps_taken = float(np.mean(successful_steps)) if successful_steps else None

    return {
        "total_simulations": total,
        "success_indices": success_indices,
        "failure_indices": failure_indices,
        "crashed_indices": crashed_indices,
        "success_count": len(success_indices),
        "failure_count": len(failure_indices),
        "crashed_count": len(crashed_indices),
        "success_rate": len(success_indices) / total,
        "failure_rate": len(failure_indices) / total,
        "crashed_rate": len(crashed_indices) / total,
        "steps_taken": steps,
        "avg_steps_taken_until_success": avg_steps_taken,
    }
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class LiberoObservationRecorder:
|
|
170
|
+
"""Context manager for recording LIBERO observations to video files.
|
|
171
|
+
|
|
172
|
+
This class is not multi-processing safe. Each process should use a different
|
|
173
|
+
(folder, camera_name) pair.
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
folder: Directory path where video files will be saved. If None, recording
|
|
177
|
+
is disabled.
|
|
178
|
+
camera_names: List of camera names to record. If None, no cameras are recorded.
|
|
179
|
+
fps: Frames per second for the output videos. Defaults to 10.
|
|
180
|
+
extension: Video file extension. Defaults to "mp4".
|
|
181
|
+
"""
|
|
182
|
+
|
|
183
|
+
def __init__(self, folder, camera_names=None, fps=10, extension="mp4"):
|
|
184
|
+
if folder is None:
|
|
185
|
+
logging.debug("No folder specified for video recording. Skipping.")
|
|
186
|
+
self.writers = []
|
|
187
|
+
self.camera_names = []
|
|
188
|
+
return
|
|
189
|
+
|
|
190
|
+
self.camera_names = camera_names or []
|
|
191
|
+
folder = Path(folder)
|
|
192
|
+
Path(folder).mkdir(parents=True, exist_ok=True)
|
|
193
|
+
video_files = [folder / f"{cam}.{extension}" for cam in self.camera_names]
|
|
194
|
+
logging.debug("Creating video files: %s", video_files)
|
|
195
|
+
self.writers = [imageio.get_writer(vf, fps=fps) for vf in video_files]
|
|
196
|
+
|
|
197
|
+
def __enter__(self):
|
|
198
|
+
return self
|
|
199
|
+
|
|
200
|
+
def record(self, obs):
|
|
201
|
+
"""Record a single observation frame.
|
|
202
|
+
|
|
203
|
+
Args:
|
|
204
|
+
obs: Observation dictionary containing camera images keyed by camera name.
|
|
205
|
+
"""
|
|
206
|
+
for writer, camera in zip(self.writers, self.camera_names, strict=True):
|
|
207
|
+
writer.append_data(np.rot90(obs[camera], k=2))
|
|
208
|
+
|
|
209
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
210
|
+
logging.debug("Closing video writers.")
|
|
211
|
+
for writer in self.writers:
|
|
212
|
+
writer.close()
|
|
213
|
+
logging.debug("Video writers closed.")
|
|
214
|
+
return False
|