opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/scripts/launch_train.py
@@ -0,0 +1,63 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import subprocess
import sys
from pathlib import Path

import opentau.scripts.train as train_script


def main():
    parser = argparse.ArgumentParser(
        description="Launch OpenTau training with Accelerate",
        usage="opentau-train [--accelerate-config CONFIG] [TRAINING_ARGS]",
    )
    parser.add_argument(
        "--accelerate-config", type=str, help="Path to accelerate config file (yaml)", default=None
    )
    # We use parse_known_args so that all other arguments are collected
    # These will be passed to the training script
    args, unknown_args = parser.parse_known_args()

    # Base command
    cmd = ["accelerate", "launch"]

    # Add accelerate config if provided
    if args.accelerate_config:
        cmd.extend(["--config_file", args.accelerate_config])

    # Add the path to the training script
    # We resolve the path to ensure it's absolute
    train_script_path = Path(train_script.__file__).resolve()
    cmd.append(str(train_script_path))

    # Add all other arguments (passed to the training script)
    cmd.extend(unknown_args)

    # Print the command for transparency
    print(f"Executing: {' '.join(cmd)}")

    # Run the accelerate launch command and propagate its exit code
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        sys.exit(e.returncode)
    except KeyboardInterrupt:
        sys.exit(130)


if __name__ == "__main__":
    main()
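A minimal sketch of how the argument forwarding in this launcher behaves: only `--accelerate-config` is consumed, and every other token is appended after the training script path. The command line, the `acc.yaml` path, and the `--policy.type`/`--batch_size` flags below are illustrative placeholders, not arguments defined by this file.

import argparse

# Hypothetical invocation: opentau-train --accelerate-config acc.yaml --policy.type pi0 --batch_size 8
argv = ["--accelerate-config", "acc.yaml", "--policy.type", "pi0", "--batch_size", "8"]

parser = argparse.ArgumentParser()
parser.add_argument("--accelerate-config", type=str, default=None)
args, unknown = parser.parse_known_args(argv)

# Mirrors the command built by main(): the known flag becomes --config_file,
# everything unrecognized is passed through to the training script unchanged.
cmd = ["accelerate", "launch"]
if args.accelerate_config:
    cmd += ["--config_file", args.accelerate_config]
cmd += ["/path/to/opentau/scripts/train.py"]  # placeholder for Path(train_script.__file__).resolve()
cmd += unknown

print(cmd)
# ['accelerate', 'launch', '--config_file', 'acc.yaml',
#  '/path/to/opentau/scripts/train.py', '--policy.type', 'pi0', '--batch_size', '8']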
opentau/scripts/libero_simulation_parallel.py
@@ -0,0 +1,356 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import logging
import os
import signal
import sys
import threading
import time
from collections import deque
from dataclasses import asdict, dataclass
from multiprocessing import Array, Pipe, Process, SimpleQueue
from multiprocessing.connection import Connection, wait
from pathlib import Path
from pprint import pformat

import numpy as np
import psutil
import torch
from einops import rearrange
from torch.utils.data._utils.collate import default_collate

from opentau.configs import parser
from opentau.configs.libero import TrainConfigWithLiberoEval
from opentau.policies.factory import get_policy_class
from opentau.utils.libero import LiberoObservationRecorder, summarize_libero_results
from opentau.utils.libero import _libero2np as libero2np
from opentau.utils.libero import _np2torch as np2torch
from opentau.utils.monkey_patch import gym_is_gymnasium_patch
from opentau.utils.random_utils import set_seed
from opentau.utils.utils import auto_torch_device

# Sent by client process to indicate simulation completion and signal that the pipe is to be closed
SENTINEL = "<SENTINEL>"

LIBERO_ACTION_DIM = 7


@dataclass
class Config(TrainConfigWithLiberoEval):
    parallel_simulation_count: int = 4
    max_wait_sec: float = 1.0
    logging_dir: str | None = None


@dataclass
class Request:
    r"""Request sent from the CPU LIBERO simulation process to the GPU policy."""

    sim_id: int
    step_id: int
    observation: dict[str, np.ndarray | str]


@dataclass
class Response:
    r"""Response sent from the GPU policy to the CPU LIBERO simulation process."""

    chunked_action: np.ndarray


class ConnectionBuffer:
    def __init__(
        self,
        conns: list[Connection],
        max_wait_sec: float,
        max_batch_size: int,
        device: str,
        dtype: torch.dtype,
    ):
        r"""Gathers a batch of inputs. Wait for no more than `max_wait_time` seconds,
        or until `max_batch_size` is reached."""
        self.conns = conns
        self.max_wait = max_wait_sec
        self.max_batch = max_batch_size
        self.device = device
        self.dtype = dtype
        self.batch_inputs = []
        self.response_list = []
        self.last_yield_time = None

    def _should_yield(self):
        # Don't yield empty batches
        if not self.batch_inputs:
            return False

        return (
            len(self.batch_inputs) >= self.max_batch
            or time.monotonic() - self.last_yield_time >= self.max_wait
            or not self.conns
        )

    def get_batch(self):
        self.last_yield_time = time.monotonic()

        while self.conns or self.batch_inputs:
            timeout = self.last_yield_time + self.max_wait - time.monotonic()
            selected = wait(self.conns, timeout=max(timeout, 0.0)) if self.conns else []
            for ready in selected:
                try:
                    req = ready.recv()
                    if req != SENTINEL:
                        xs = np2torch(req.observation, self.device, self.dtype)
                except Exception as e:  # In case the simulation process crashed
                    logging.error(str(e))
                    req = SENTINEL

                if req == SENTINEL:
                    logging.debug("Removing connection")
                    self.conns.remove(ready)
                    ready.close()
                    continue

                logging.debug(f"Received a request from sim {req.sim_id} at step {req.step_id}")

                self.batch_inputs.append(xs)
                self.response_list.append(ready)
                if self._should_yield():
                    break

            if self._should_yield():
                bi, br = self.batch_inputs, self.response_list
                self.batch_inputs, self.response_list = [], []
                self.last_yield_time = time.monotonic()
                yield bi, br


def start_parent_check_thread():
    def is_process_active(pid):
        try:
            process = psutil.Process(pid)
            return process.is_running() and process.status() != psutil.STATUS_ZOMBIE
        except psutil.NoSuchProcess:
            return False

    def kill_child_processes(parent_pid):
        parent = psutil.Process(parent_pid)
        for child in parent.children(recursive=True):
            try:
                os.kill(child.pid, signal.SIGKILL)
                logging.warning(f"Killed pid {child.pid}")
            except BaseException as e:
                logging.warning(f"Killing pid {child.pid} failed {str(e)}")

    def check_parent_alive():
        parent_pid = os.getppid()
        while True:
            if not is_process_active(parent_pid):
                logging.warning(f"Parent is dead, kill self {os.getpid()}")
                kill_child_processes(os.getpid())
                os.kill(os.getpid(), signal.SIGKILL)

            time.sleep(10)

    thread = threading.Thread(target=check_parent_alive, daemon=True)
    thread.start()


def server(cfg: Config, conns: list[Connection], device: str, dtype: torch.dtype):
    r"""Runs a server in the main process that creates a policy and listens for observations from clients"""
    init_proc_logging(None, cfg)
    logging.info(pformat(asdict(cfg)))

    policy_class = get_policy_class(cfg.policy.type)
    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
    policy.to(device=device, dtype=dtype)
    policy.eval()

    connection_buffer = ConnectionBuffer(
        conns,
        max_wait_sec=cfg.max_wait_sec,
        max_batch_size=cfg.batch_size,
        device=device,
        dtype=dtype,
    )

    with torch.inference_mode():
        for batch_inputs, resp_conns in connection_buffer.get_batch():
            if not batch_inputs:
                logging.debug("Got empty batch, continuing.")
                continue
            logging.debug(f"Received batch of size {len(batch_inputs)}")
            batch_inputs = default_collate(batch_inputs)
            # We return the entire action chunk and let the simulation process handle the caching.
            batch_chunked_actions = policy.sample_actions(batch_inputs)
            batch_chunked_actions = rearrange(
                batch_chunked_actions, "chunk batch action -> batch chunk action"
            )
            batch_chunked_actions = batch_chunked_actions.numpy(force=True)
            batch_chunked_actions = batch_chunked_actions[:, : cfg.libero.chunk_usage, :LIBERO_ACTION_DIM]
            # gripper open/close should be -1 or 1
            batch_chunked_actions[:, :, -1] = 2.0 * (batch_chunked_actions[:, :, -1] > 0) - 1.0

            for chunked_actions, conn in zip(batch_chunked_actions, resp_conns, strict=True):
                resp = Response(chunked_action=chunked_actions)
                logging.debug(f"sending action of shape {resp.chunked_action.shape} to simulation")
                conn.send(resp)


def simulation(worker_id: int, cfg: Config, job_q: SimpleQueue, results_arr: Array, conn: Connection):
    r"""Runs a simulation in a separate process. Sends observations to the server and receives actions."""
    init_proc_logging(worker_id, cfg)
    start_parent_check_thread()

    # Patch gym before importing OffScreenRenderEnv at the start of the sim process.
    gym_is_gymnasium_patch()
    from libero.libero.envs import OffScreenRenderEnv

    init_states = cfg.libero.init_states
    while True:
        sim_id = job_q.get()
        if sim_id == SENTINEL:
            logging.debug(f"Simulation process {os.getpid()} received SENTINEL, exiting.")
            conn.send(SENTINEL)
            conn.close()
            return

        # This environment provides interaction with the policy without rendering a UI.
        # To record videos, we use the `LiberoObservationRecorder` class and manually record frames.
        env = OffScreenRenderEnv(
            bddl_file_name=cfg.libero.bddl_file,
            camera_heights=cfg.resolution[0],
            camera_widths=cfg.resolution[1],
        )
        env.seed(sim_id)
        env.set_init_state(init_states[sim_id % len(init_states)])
        video_root = cfg.libero.video_dir and (
            Path(cfg.libero.video_dir) / cfg.libero.suite / str(cfg.libero.id) / str(sim_id)
        )
        camera_names = ["agentview_image", "robot0_eye_in_hand_image"]
        with LiberoObservationRecorder(video_root, camera_names=camera_names) as recorder:
            obs = env.reset()
            # Warm up the environment with a few no-op steps
            for _ in range(5):
                obs, *_ = env.step([0.0] * LIBERO_ACTION_DIM)
                recorder.record(obs)
            action_cache = []

            finish_step = -1
            for step_id in range(1, cfg.libero.max_steps + 1):
                if len(action_cache) == 0:
                    req = Request(sim_id=sim_id, step_id=step_id, observation=libero2np(obs, cfg))
                    logging.debug(f"Sending observation at step {step_id}")
                    conn.send(req)
                    resp = conn.recv()
                    logging.debug(f"Received action chunk with shape: {resp.chunked_action.shape}")
                    action_cache = deque(resp.chunked_action)

                action = action_cache.popleft()
                obs, reward, done, info = env.step(action)
                recorder.record(obs)

                logging.debug(f"Step: {step_id}, Reward: {reward}, Done: {done}, Info: {info}")
                if done or reward > 0:
                    finish_step = step_id
                    break

        logging.info(f"Result is {finish_step=}")

        if sim_id >= len(results_arr):
            # Should never happen
            logging.error(f"sim_id {sim_id} exceeds results array size {len(results_arr)}")

        results_arr[sim_id] = finish_step


def init_proc_logging(worker_id: int | None, cfg: Config):
    r"""Initialize logging for server or worker processes."""
    handlers = [
        logging.StreamHandler(sys.stdout),
    ]

    if cfg.logging_dir is not None:
        filename = f"worker_{worker_id:03d}.log" if worker_id is not None else "server.log"
        directory = Path(cfg.logging_dir)
        directory.mkdir(parents=True, exist_ok=True)
        handlers.append(logging.FileHandler(directory / filename))

    prefix = "SERVER" if worker_id is None else f"WORKER-{worker_id:03d}"
    logging.basicConfig(
        level=logging.DEBUG if cfg.debug else logging.INFO,
        format=f"{prefix}: %(asctime)s %(levelname)s %(message)s",
        handlers=handlers,
        force=True,
    )
    logging.info(f"Initialized in process {os.getpid()} by parent {os.getppid()}")


@parser.wrap()
def main(cfg: Config):
    device = auto_torch_device()
    dtype = torch.bfloat16

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # job queue contains simulation IDs to be processed, and `SENTINEL`s to signal completion
    job_queue = SimpleQueue()
    for sim_id in range(cfg.libero.n_simulations):
        job_queue.put(sim_id)
    for _ in range(cfg.parallel_simulation_count):
        job_queue.put(SENTINEL)

    # Shared memory mapping for results. Since each simulation is only handled by one process, no lock is needed.
    # -2 indicates uninitialized, -1 indicates failure to complete the task.
    results_arr = Array(ctypes.c_int64, [-2] * cfg.libero.n_simulations, lock=False)

    sim_procs, conns = [], []
    for worker_id in range(cfg.parallel_simulation_count):
        server_conn, client_conn = Pipe()
        conns.append(server_conn)

        # TODO ensure p is killed if the main process is killed
        # TODO ensure that when p is killed, the client_conn is closed
        p = Process(
            target=simulation,
            args=(
                worker_id,
                cfg,  # cfg must be unpickle-able in sub-processes
                job_queue,
                results_arr,
                client_conn,
            ),
        )
        sim_procs.append(p)

        p.start()  # Start the process before closing the client connection
        client_conn.close()

    server(cfg, conns, device, dtype)

    logging.debug("Joining simulation processes...")
    for p in sim_procs:
        p.join()

    logging.debug("All simulations completed. Gathering results...")
    summary = summarize_libero_results(results_arr[:])
    logging.info(str(summary))
    for k, v in summary.items():
        print(k, v)


if __name__ == "__main__":
    main()
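The server/worker split above hinges on `multiprocessing.connection.wait` plus a time-and-size budget: workers send observations over Pipes, the server batches whatever arrives within the window, runs the policy once, and replies per connection. Below is a minimal, self-contained sketch of that same batching pattern with an echo reply standing in for the policy; the names and parameters are illustrative and not part of opentau.

# Toy version of the ConnectionBuffer pattern: gather requests from several
# worker pipes until a deadline or batch-size limit, then reply per request.
import time
from multiprocessing import Pipe, Process
from multiprocessing.connection import wait

SENTINEL = "<SENTINEL>"


def worker(conn, worker_id):
    for step in range(3):
        conn.send({"worker": worker_id, "step": step})
        reply = conn.recv()  # blocks until the server answers
        assert reply["worker"] == worker_id
    conn.send(SENTINEL)  # tell the server this pipe is done
    conn.close()


def serve(conns, max_wait=0.2, max_batch=4):
    while conns:
        batch, ready_conns = [], []
        deadline = time.monotonic() + max_wait
        while conns and len(batch) < max_batch and time.monotonic() < deadline:
            for ready in wait(conns, timeout=max(deadline - time.monotonic(), 0.0)):
                msg = ready.recv()
                if msg == SENTINEL:
                    conns.remove(ready)
                    ready.close()
                    continue
                batch.append(msg)
                ready_conns.append(ready)
        # "Run the policy" on the whole batch, then send one reply per request.
        for msg, conn in zip(batch, ready_conns):
            conn.send(msg)


if __name__ == "__main__":
    pairs = [Pipe() for _ in range(3)]
    procs = [Process(target=worker, args=(child, i)) for i, (_, child) in enumerate(pairs)]
    for p in procs:
        p.start()
    serve([parent for parent, _ in pairs])
    for p in procs:
        p.join()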
opentau/scripts/libero_simulation_sequential.py
@@ -0,0 +1,122 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import asdict
from pathlib import Path
from pprint import pformat

import torch
from torch.utils.data._utils.collate import default_collate

from opentau.configs import parser
from opentau.configs.libero import TrainConfigWithLiberoEval
from opentau.policies.factory import get_policy_class
from opentau.policies.pretrained import PreTrainedPolicy
from opentau.utils.libero import LiberoObservationRecorder, libero2torch, summarize_libero_results
from opentau.utils.monkey_patch import gym_is_gymnasium_patch
from opentau.utils.random_utils import set_seed
from opentau.utils.utils import auto_torch_device, init_logging

LIBERO_ACTION_DIM = 7


def run_simulations(
    policy: PreTrainedPolicy, cfg: TrainConfigWithLiberoEval, device: str, dtype: torch.dtype
):
    gym_is_gymnasium_patch()
    # This import has to happen after the `gym_is_gymnasium_patch` is called,
    # so we can't put it at the top of the file.
    from libero.libero.envs import OffScreenRenderEnv

    init_states = cfg.libero.init_states

    steps_taken = {}
    for sim_idx in range(1, cfg.libero.n_simulations + 1):
        # This environment provides interaction with the policy without rendering a UI.
        # To record videos, we use the `LiberoObservationRecorder` class and manually record frames.
        env = OffScreenRenderEnv(
            bddl_file_name=cfg.libero.bddl_file,
            camera_heights=cfg.resolution[0],
            camera_widths=cfg.resolution[1],
        )
        s0 = init_states[sim_idx % len(init_states)]
        env.seed(sim_idx)
        env.set_init_state(s0)

        video_root = cfg.libero.video_dir and (
            Path(cfg.libero.video_dir) / cfg.libero.suite / str(cfg.libero.id) / str(sim_idx)
        )
        camera_names = ["agentview_image", "robot0_eye_in_hand_image"]
        with LiberoObservationRecorder(video_root, camera_names=camera_names) as recorder:
            obs = env.reset()
            # Warm up the environment with a few no-op steps
            for _ in range(5):
                obs, *_ = env.step([0.0] * LIBERO_ACTION_DIM)
                recorder.record(obs)

            for step_idx in range(cfg.libero.max_steps):
                if step_idx % cfg.libero.chunk_usage == 0:
                    logging.debug(f"Resetting policy before step {step_idx + 1} for simulation {sim_idx}")
                    # Invalidate the cache and force the policy to recompute a new batch of actions
                    policy.reset()

                torch_input = libero2torch(obs, cfg, device, dtype)
                torch_input = default_collate([torch_input])
                action = policy.select_action(torch_input)
                action = action.flatten().numpy(force=True)[:LIBERO_ACTION_DIM]
                action[-1] = 2.0 * (action[-1] > 0) - 1.0  # gripper open/close should be -1 or 1
                obs, reward, done, info = env.step(action)
                recorder.record(obs)
                logging.debug(f"Step: {step_idx + 1}, Reward: {reward}, Done: {done}, Info: {info}")
                if done or reward > 0:
                    steps_taken[sim_idx] = step_idx + 1
                    break

        env.close()

    return steps_taken


@parser.wrap()
def main(cfg: TrainConfigWithLiberoEval):
    init_logging(level=logging.DEBUG if cfg.debug else logging.INFO)
    logging.info(pformat(asdict(cfg)))

    device = auto_torch_device()
    dtype = torch.bfloat16

    if cfg.seed is not None:
        set_seed(cfg.seed)

    logging.info("Creating policy")
    policy_class = get_policy_class(cfg.policy.type)
    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
    policy.to(device=device, dtype=torch.bfloat16)
    policy.eval()

    with torch.inference_mode():
        steps_taken = run_simulations(policy, cfg, device, dtype)

    results = [-1] * cfg.libero.n_simulations
    for sim_idx, step in steps_taken.items():
        results[sim_idx - 1] = step
    summary = summarize_libero_results(results)
    logging.info(str(summary))
    for k, v in summary.items():
        print(k, v)


if __name__ == "__main__":
    main()
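Both the parallel and sequential scripts reduce their runs to a flat list where entry i is the finishing step of simulation i, or a negative value when it did not succeed (-1 for failure, -2 for never run in the parallel script), and hand that list to `summarize_libero_results`. The exact fields that function returns are not shown in this diff; the sketch below is a hypothetical stand-in that only illustrates the kind of aggregate such a list supports.

# Hypothetical summarizer over the "finishing step or negative" results list.
# Not the opentau implementation of summarize_libero_results.
def toy_summarize(results: list[int]) -> dict[str, float]:
    successes = [step for step in results if step > 0]
    summary = {
        "n_simulations": len(results),
        "n_success": len(successes),
        "success_rate": len(successes) / len(results) if results else 0.0,
    }
    if successes:
        summary["mean_steps_to_success"] = sum(successes) / len(successes)
    return summary


print(toy_summarize([120, -1, 87, -2, 45]))
# {'n_simulations': 5, 'n_success': 3, 'success_rate': 0.6, 'mean_steps_to_success': 84.0}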
opentau/scripts/nav_high_level_planner_inference.py
@@ -0,0 +1,61 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os

from dotenv import load_dotenv
from PIL import Image

from opentau.planner import NavHighLevelPlanner
from opentau.utils.utils import (
    init_logging,
)

load_dotenv()


def main(img_dir_path):
    frames = sorted(os.listdir(img_dir_path))
    logging.info("Loading the frames")
    img_dict1 = {}
    for i, image_path in enumerate(frames):
        img = Image.open(img_dir_path + "/" + image_path).convert("RGB")
        img_dict1[i] = img

    # dummy instructions
    task = "The goal is to reach till fridge"
    nav_planner = NavHighLevelPlanner()
    logging.info("Inferencing the navigational planner")
    actions = nav_planner.inference(image_dict=img_dict1, model_name="gpt4o", task=task, mem=None)

    logging.info(f"The instructions are {actions}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the navigation high level planner with a specified image directory."
    )

    # 2. Add the --img_path argument
    parser.add_argument(
        "--img_path", type=str, required=True, help="Path to the directory containing the image frames."
    )

    # 3. Parse the arguments from the command line
    args = parser.parse_args()

    init_logging()
    main(args.img_path)
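One caveat with the frame loading in `main()` above: `sorted(os.listdir(...))` is lexicographic, so `frame_10.png` sorts before `frame_2.png` unless filenames are zero-padded. If the directory uses unpadded numbering (an assumption about the data, not something this script states), a numeric sort key such as the sketch below keeps frames in temporal order.

import re

# Hypothetical helper: sort frame filenames by the first integer they contain,
# falling back to plain lexicographic order when no digits are present.
def numeric_key(name: str):
    match = re.search(r"\d+", name)
    return (0, int(match.group()), name) if match else (1, 0, name)


frames = ["frame_10.png", "frame_2.png", "frame_1.png"]
print(sorted(frames))                   # ['frame_1.png', 'frame_10.png', 'frame_2.png']
print(sorted(frames, key=numeric_key))  # ['frame_1.png', 'frame_2.png', 'frame_10.png']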