opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/scripts/launch_train.py
@@ -0,0 +1,63 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ import opentau.scripts.train as train_script
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Launch OpenTau training with Accelerate",
+         usage="opentau-train [--accelerate-config CONFIG] [TRAINING_ARGS]",
+     )
+     parser.add_argument(
+         "--accelerate-config", type=str, help="Path to accelerate config file (yaml)", default=None
+     )
+     # We use parse_known_args so that all other arguments are collected.
+     # These will be passed to the training script.
+     args, unknown_args = parser.parse_known_args()
+
+     # Base command
+     cmd = ["accelerate", "launch"]
+
+     # Add accelerate config if provided
+     if args.accelerate_config:
+         cmd.extend(["--config_file", args.accelerate_config])
+
+     # Add the path to the training script.
+     # We resolve the path to ensure it's absolute.
+     train_script_path = Path(train_script.__file__).resolve()
+     cmd.append(str(train_script_path))
+
+     # Add all other arguments (passed to the training script)
+     cmd.extend(unknown_args)
+
+     # Print the command for transparency
+     print(f"Executing: {' '.join(cmd)}")
+
+     # Run the accelerate launch command and propagate its exit status
+     try:
+         subprocess.run(cmd, check=True)
+     except subprocess.CalledProcessError as e:
+         sys.exit(e.returncode)
+     except KeyboardInterrupt:
+         sys.exit(130)
+
+
+ if __name__ == "__main__":
+     main()
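The launcher above relies on `argparse.parse_known_args` to consume only its own flag and forward everything else to `opentau/scripts/train.py` untouched. Below is a minimal, self-contained sketch of that pass-through behavior; the config filename and training flags shown are illustrative, not taken from the package.

    # Illustrative sketch of the parse_known_args pass-through pattern.
    # The config filename and training flags below are made up for demonstration.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--accelerate-config", default=None)

    argv = ["--accelerate-config", "fsdp.yaml", "--steps", "1000", "--policy.type", "pi0"]
    args, unknown = parser.parse_known_args(argv)

    print(args.accelerate_config)  # fsdp.yaml
    print(unknown)                 # ['--steps', '1000', '--policy.type', 'pi0']

Only the flags registered on the parser are consumed; everything else stays in `unknown` in its original order, which is what lets the wrapper append it verbatim to the `accelerate launch` command.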
opentau/scripts/libero_simulation_parallel.py
@@ -0,0 +1,356 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import ctypes
+ import logging
+ import os
+ import signal
+ import sys
+ import threading
+ import time
+ from collections import deque
+ from dataclasses import asdict, dataclass
+ from multiprocessing import Array, Pipe, Process, SimpleQueue
+ from multiprocessing.connection import Connection, wait
+ from pathlib import Path
+ from pprint import pformat
+
+ import numpy as np
+ import psutil
+ import torch
+ from einops import rearrange
+ from torch.utils.data._utils.collate import default_collate
+
+ from opentau.configs import parser
+ from opentau.configs.libero import TrainConfigWithLiberoEval
+ from opentau.policies.factory import get_policy_class
+ from opentau.utils.libero import LiberoObservationRecorder, summarize_libero_results
+ from opentau.utils.libero import _libero2np as libero2np
+ from opentau.utils.libero import _np2torch as np2torch
+ from opentau.utils.monkey_patch import gym_is_gymnasium_patch
+ from opentau.utils.random_utils import set_seed
+ from opentau.utils.utils import auto_torch_device
+
+ # Sent by a client process to indicate simulation completion and signal that the pipe is to be closed
+ SENTINEL = "<SENTINEL>"
+
+ LIBERO_ACTION_DIM = 7
+
+
+ @dataclass
+ class Config(TrainConfigWithLiberoEval):
+     parallel_simulation_count: int = 4
+     max_wait_sec: float = 1.0
+     logging_dir: str | None = None
+
+
+ @dataclass
+ class Request:
+     r"""Request sent from the CPU LIBERO simulation process to the GPU policy."""
+
+     sim_id: int
+     step_id: int
+     observation: dict[str, np.ndarray | str]
+
+
+ @dataclass
+ class Response:
+     r"""Response sent from the GPU policy to the CPU LIBERO simulation process."""
+
+     chunked_action: np.ndarray
+
+
+ class ConnectionBuffer:
+     def __init__(
+         self,
+         conns: list[Connection],
+         max_wait_sec: float,
+         max_batch_size: int,
+         device: str,
+         dtype: torch.dtype,
+     ):
+         r"""Gathers a batch of inputs. Waits for no more than `max_wait_sec` seconds,
+         or until `max_batch_size` is reached."""
+         self.conns = conns
+         self.max_wait = max_wait_sec
+         self.max_batch = max_batch_size
+         self.device = device
+         self.dtype = dtype
+         self.batch_inputs = []
+         self.response_list = []
+         self.last_yield_time = None
+
+     def _should_yield(self):
+         # Don't yield empty batches
+         if not self.batch_inputs:
+             return False
+
+         return (
+             len(self.batch_inputs) >= self.max_batch
+             or time.monotonic() - self.last_yield_time >= self.max_wait
+             or not self.conns
+         )
+
+     def get_batch(self):
+         self.last_yield_time = time.monotonic()
+
+         while self.conns or self.batch_inputs:
+             timeout = self.last_yield_time + self.max_wait - time.monotonic()
+             selected = wait(self.conns, timeout=max(timeout, 0.0)) if self.conns else []
+             for ready in selected:
+                 try:
+                     req = ready.recv()
+                     if req != SENTINEL:
+                         xs = np2torch(req.observation, self.device, self.dtype)
+                 except Exception as e:  # In case the simulation process crashed
+                     logging.error(str(e))
+                     req = SENTINEL
+
+                 if req == SENTINEL:
+                     logging.debug("Removing connection")
+                     self.conns.remove(ready)
+                     ready.close()
+                     continue
+
+                 logging.debug(f"Received a request from sim {req.sim_id} at step {req.step_id}")
+
+                 self.batch_inputs.append(xs)
+                 self.response_list.append(ready)
+                 if self._should_yield():
+                     break
+
+             if self._should_yield():
+                 bi, br = self.batch_inputs, self.response_list
+                 self.batch_inputs, self.response_list = [], []
+                 self.last_yield_time = time.monotonic()
+                 yield bi, br
+
+
+ def start_parent_check_thread():
+     def is_process_active(pid):
+         try:
+             process = psutil.Process(pid)
+             return process.is_running() and process.status() != psutil.STATUS_ZOMBIE
+         except psutil.NoSuchProcess:
+             return False
+
+     def kill_child_processes(parent_pid):
+         parent = psutil.Process(parent_pid)
+         for child in parent.children(recursive=True):
+             try:
+                 os.kill(child.pid, signal.SIGKILL)
+                 logging.warning(f"Killed pid {child.pid}")
+             except BaseException as e:
+                 logging.warning(f"Killing pid {child.pid} failed: {str(e)}")
+
+     def check_parent_alive():
+         parent_pid = os.getppid()
+         while True:
+             if not is_process_active(parent_pid):
+                 logging.warning(f"Parent is dead, kill self {os.getpid()}")
+                 kill_child_processes(os.getpid())
+                 os.kill(os.getpid(), signal.SIGKILL)
+
+             time.sleep(10)
+
+     thread = threading.Thread(target=check_parent_alive, daemon=True)
+     thread.start()
+
+
+ def server(cfg: Config, conns: list[Connection], device: str, dtype: torch.dtype):
+     r"""Runs a server in the main process that creates a policy and listens for observations from clients."""
+     init_proc_logging(None, cfg)
+     logging.info(pformat(asdict(cfg)))
+
+     policy_class = get_policy_class(cfg.policy.type)
+     policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
+     policy.to(device=device, dtype=dtype)
+     policy.eval()
+
+     connection_buffer = ConnectionBuffer(
+         conns,
+         max_wait_sec=cfg.max_wait_sec,
+         max_batch_size=cfg.batch_size,
+         device=device,
+         dtype=dtype,
+     )
+
+     with torch.inference_mode():
+         for batch_inputs, resp_conns in connection_buffer.get_batch():
+             if not batch_inputs:
+                 logging.debug("Got empty batch, continuing.")
+                 continue
+             logging.debug(f"Received batch of size {len(batch_inputs)}")
+             batch_inputs = default_collate(batch_inputs)
+             # We return the entire action chunk and let the simulation process handle the caching.
+             batch_chunked_actions = policy.sample_actions(batch_inputs)
+             batch_chunked_actions = rearrange(
+                 batch_chunked_actions, "chunk batch action -> batch chunk action"
+             )
+             batch_chunked_actions = batch_chunked_actions.numpy(force=True)
+             batch_chunked_actions = batch_chunked_actions[:, : cfg.libero.chunk_usage, :LIBERO_ACTION_DIM]
+             # gripper open/close should be -1 or 1
+             batch_chunked_actions[:, :, -1] = 2.0 * (batch_chunked_actions[:, :, -1] > 0) - 1.0
+
+             for chunked_actions, conn in zip(batch_chunked_actions, resp_conns, strict=True):
+                 resp = Response(chunked_action=chunked_actions)
+                 logging.debug(f"Sending action of shape {resp.chunked_action.shape} to simulation")
+                 conn.send(resp)
+
+
+ def simulation(worker_id: int, cfg: Config, job_q: SimpleQueue, results_arr: Array, conn: Connection):
+     r"""Runs a simulation in a separate process. Sends observations to the server and receives actions."""
+     init_proc_logging(worker_id, cfg)
+     start_parent_check_thread()
+
+     # Patch gym before importing OffScreenRenderEnv at the start of the sim process.
+     gym_is_gymnasium_patch()
+     from libero.libero.envs import OffScreenRenderEnv
+
+     init_states = cfg.libero.init_states
+     while True:
+         sim_id = job_q.get()
+         if sim_id == SENTINEL:
+             logging.debug(f"Simulation process {os.getpid()} received SENTINEL, exiting.")
+             conn.send(SENTINEL)
+             conn.close()
+             return
+
+         # This environment provides interaction with the policy without rendering a UI.
+         # To record videos, we use the `LiberoObservationRecorder` class and manually record frames.
+         env = OffScreenRenderEnv(
+             bddl_file_name=cfg.libero.bddl_file,
+             camera_heights=cfg.resolution[0],
+             camera_widths=cfg.resolution[1],
+         )
+         env.seed(sim_id)
+         env.set_init_state(init_states[sim_id % len(init_states)])
+         video_root = cfg.libero.video_dir and (
+             Path(cfg.libero.video_dir) / cfg.libero.suite / str(cfg.libero.id) / str(sim_id)
+         )
+         camera_names = ["agentview_image", "robot0_eye_in_hand_image"]
+         with LiberoObservationRecorder(video_root, camera_names=camera_names) as recorder:
+             obs = env.reset()
+             # Warm up the environment with a few no-op steps
+             for _ in range(5):
+                 obs, *_ = env.step([0.0] * LIBERO_ACTION_DIM)
+                 recorder.record(obs)
+             action_cache = []
+
+             finish_step = -1
+             for step_id in range(1, cfg.libero.max_steps + 1):
+                 if len(action_cache) == 0:
+                     req = Request(sim_id=sim_id, step_id=step_id, observation=libero2np(obs, cfg))
+                     logging.debug(f"Sending observation at step {step_id}")
+                     conn.send(req)
+                     resp = conn.recv()
+                     logging.debug(f"Received action chunk with shape: {resp.chunked_action.shape}")
+                     action_cache = deque(resp.chunked_action)
+
+                 action = action_cache.popleft()
+                 obs, reward, done, info = env.step(action)
+                 recorder.record(obs)
+
+                 logging.debug(f"Step: {step_id}, Reward: {reward}, Done: {done}, Info: {info}")
+                 if done or reward > 0:
+                     finish_step = step_id
+                     break
+
+         logging.info(f"Result is {finish_step=}")
+
+         if sim_id >= len(results_arr):
+             # Should never happen
+             logging.error(f"sim_id {sim_id} exceeds results array size {len(results_arr)}")
+
+         results_arr[sim_id] = finish_step
+
+
+ def init_proc_logging(worker_id: int | None, cfg: Config):
+     r"""Initialize logging for the server or worker processes."""
+     handlers = [
+         logging.StreamHandler(sys.stdout),
+     ]
+
+     if cfg.logging_dir is not None:
+         filename = f"worker_{worker_id:03d}.log" if worker_id is not None else "server.log"
+         directory = Path(cfg.logging_dir)
+         directory.mkdir(parents=True, exist_ok=True)
+         handlers.append(logging.FileHandler(directory / filename))
+
+     prefix = "SERVER" if worker_id is None else f"WORKER-{worker_id:03d}"
+     logging.basicConfig(
+         level=logging.DEBUG if cfg.debug else logging.INFO,
+         format=f"{prefix}: %(asctime)s %(levelname)s %(message)s",
+         handlers=handlers,
+         force=True,
+     )
+     logging.info(f"Initialized in process {os.getpid()} by parent {os.getppid()}")
+
+
+ @parser.wrap()
+ def main(cfg: Config):
+     device = auto_torch_device()
+     dtype = torch.bfloat16
+
+     if cfg.seed is not None:
+         set_seed(cfg.seed)
+
+     # The job queue contains simulation IDs to be processed, plus `SENTINEL`s to signal completion.
+     job_queue = SimpleQueue()
+     for sim_id in range(cfg.libero.n_simulations):
+         job_queue.put(sim_id)
+     for _ in range(cfg.parallel_simulation_count):
+         job_queue.put(SENTINEL)
+
+     # Shared memory mapping for results. Since each simulation is only handled by one process, no lock is needed.
+     # -2 indicates uninitialized, -1 indicates failure to complete the task.
+     results_arr = Array(ctypes.c_int64, [-2] * cfg.libero.n_simulations, lock=False)
+
+     sim_procs, conns = [], []
+     for worker_id in range(cfg.parallel_simulation_count):
+         server_conn, client_conn = Pipe()
+         conns.append(server_conn)
+
+         # TODO: ensure p is killed if the main process is killed
+         # TODO: ensure that when p is killed, the client_conn is closed
+         p = Process(
+             target=simulation,
+             args=(
+                 worker_id,
+                 cfg,  # cfg must be unpickle-able in sub-processes
+                 job_queue,
+                 results_arr,
+                 client_conn,
+             ),
+         )
+         sim_procs.append(p)
+
+         p.start()  # Start the process before closing the client connection
+         client_conn.close()
+
+     server(cfg, conns, device, dtype)
+
+     logging.debug("Joining simulation processes...")
+     for p in sim_procs:
+         p.join()
+
+     logging.debug("All simulations completed. Gathering results...")
+     summary = summarize_libero_results(results_arr[:])
+     logging.info(str(summary))
+     for k, v in summary.items():
+         print(k, v)
+
+
+ if __name__ == "__main__":
+     main()
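The `ConnectionBuffer` above multiplexes many CPU simulation processes onto one GPU policy: each worker sends a `Request` over its `Pipe`, the server drains whichever connections are ready with `multiprocessing.connection.wait`, batches the inputs, and replies on the same connection, with `SENTINEL` marking a worker that has finished. The following is a stripped-down sketch of that request/response pattern with the policy replaced by a trivial echo; all names here are illustrative and not part of the package API.

    # Stripped-down sketch of the pipe-based request/response pattern above.
    # The "policy" is a trivial echo; names are illustrative only.
    from multiprocessing import Pipe, Process
    from multiprocessing.connection import wait

    SENTINEL = "<SENTINEL>"


    def worker(conn, n_requests):
        for i in range(n_requests):
            conn.send({"observation": i})
            print("worker got:", conn.recv())
        conn.send(SENTINEL)  # tell the server this worker is finished
        conn.close()


    def serve(conns):
        while conns:
            for ready in wait(conns, timeout=1.0):
                msg = ready.recv()
                if msg == SENTINEL:
                    conns.remove(ready)
                    ready.close()
                    continue
                ready.send({"action": msg["observation"] * 2})  # stand-in for the policy call


    if __name__ == "__main__":
        server_conns, procs = [], []
        for _ in range(2):
            server_conn, client_conn = Pipe()
            server_conns.append(server_conn)
            p = Process(target=worker, args=(client_conn, 3))
            p.start()
            procs.append(p)
            client_conn.close()  # the child process keeps its own handle
        serve(server_conns)
        for p in procs:
            p.join()

The same shape appears in the real script: the server owns one end of each pipe, the batching logic decides when enough requests have accumulated (or enough time has passed), and removing a connection on `SENTINEL` is what lets the serve loop terminate once every worker has drained the job queue.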
opentau/scripts/libero_simulation_sequential.py
@@ -0,0 +1,122 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from dataclasses import asdict
+ from pathlib import Path
+ from pprint import pformat
+
+ import torch
+ from torch.utils.data._utils.collate import default_collate
+
+ from opentau.configs import parser
+ from opentau.configs.libero import TrainConfigWithLiberoEval
+ from opentau.policies.factory import get_policy_class
+ from opentau.policies.pretrained import PreTrainedPolicy
+ from opentau.utils.libero import LiberoObservationRecorder, libero2torch, summarize_libero_results
+ from opentau.utils.monkey_patch import gym_is_gymnasium_patch
+ from opentau.utils.random_utils import set_seed
+ from opentau.utils.utils import auto_torch_device, init_logging
+
+ LIBERO_ACTION_DIM = 7
+
+
+ def run_simulations(
+     policy: PreTrainedPolicy, cfg: TrainConfigWithLiberoEval, device: str, dtype: torch.dtype
+ ):
+     gym_is_gymnasium_patch()
+     # This import has to happen after `gym_is_gymnasium_patch` is called,
+     # so we can't put it at the top of the file.
+     from libero.libero.envs import OffScreenRenderEnv
+
+     init_states = cfg.libero.init_states
+
+     steps_taken = {}
+     for sim_idx in range(1, cfg.libero.n_simulations + 1):
+         # This environment provides interaction with the policy without rendering a UI.
+         # To record videos, we use the `LiberoObservationRecorder` class and manually record frames.
+         env = OffScreenRenderEnv(
+             bddl_file_name=cfg.libero.bddl_file,
+             camera_heights=cfg.resolution[0],
+             camera_widths=cfg.resolution[1],
+         )
+         s0 = init_states[sim_idx % len(init_states)]
+         env.seed(sim_idx)
+         env.set_init_state(s0)
+
+         video_root = cfg.libero.video_dir and (
+             Path(cfg.libero.video_dir) / cfg.libero.suite / str(cfg.libero.id) / str(sim_idx)
+         )
+         camera_names = ["agentview_image", "robot0_eye_in_hand_image"]
+         with LiberoObservationRecorder(video_root, camera_names=camera_names) as recorder:
+             obs = env.reset()
+             # Warm up the environment with a few no-op steps
+             for _ in range(5):
+                 obs, *_ = env.step([0.0] * LIBERO_ACTION_DIM)
+                 recorder.record(obs)
+
+             for step_idx in range(cfg.libero.max_steps):
+                 if step_idx % cfg.libero.chunk_usage == 0:
+                     logging.debug(f"Resetting policy before step {step_idx + 1} for simulation {sim_idx}")
+                     # Invalidate the cache and force the policy to recompute a new batch of actions
+                     policy.reset()
+
+                 torch_input = libero2torch(obs, cfg, device, dtype)
+                 torch_input = default_collate([torch_input])
+                 action = policy.select_action(torch_input)
+                 action = action.flatten().numpy(force=True)[:LIBERO_ACTION_DIM]
+                 action[-1] = 2.0 * (action[-1] > 0) - 1.0  # gripper open/close should be -1 or 1
+                 obs, reward, done, info = env.step(action)
+                 recorder.record(obs)
+                 logging.debug(f"Step: {step_idx + 1}, Reward: {reward}, Done: {done}, Info: {info}")
+                 if done or reward > 0:
+                     steps_taken[sim_idx] = step_idx + 1
+                     break
+
+         env.close()
+
+     return steps_taken
+
+
+ @parser.wrap()
+ def main(cfg: TrainConfigWithLiberoEval):
+     init_logging(level=logging.DEBUG if cfg.debug else logging.INFO)
+     logging.info(pformat(asdict(cfg)))
+
+     device = auto_torch_device()
+     dtype = torch.bfloat16
+
+     if cfg.seed is not None:
+         set_seed(cfg.seed)
+
+     logging.info("Creating policy")
+     policy_class = get_policy_class(cfg.policy.type)
+     policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
+     policy.to(device=device, dtype=torch.bfloat16)
+     policy.eval()
+
+     with torch.inference_mode():
+         steps_taken = run_simulations(policy, cfg, device, dtype)
+
+     results = [-1] * cfg.libero.n_simulations
+     for sim_idx, step in steps_taken.items():
+         results[sim_idx - 1] = step
+     summary = summarize_libero_results(results)
+     logging.info(str(summary))
+     for k, v in summary.items():
+         print(k, v)
+
+
+ if __name__ == "__main__":
+     main()
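In the sequential loop above, `policy.reset()` is called every `cfg.libero.chunk_usage` steps so the policy discards its cached action chunk and predicts a fresh one, while the calls in between consume the cache; the last action dimension is then binarized with `2.0 * (x > 0) - 1.0`, which maps any positive gripper value to +1 and everything else to -1. Here is a toy sketch of that cadence using a stand-in policy with a simple internal cache; `CachingPolicy` is an assumption for illustration and does not reflect the package's `PreTrainedPolicy` API.

    # Toy sketch of the chunk-reuse cadence. `CachingPolicy` is a stand-in class,
    # not the package's policy implementation.
    import numpy as np


    class CachingPolicy:
        def __init__(self, horizon: int = 10, action_dim: int = 7):
            self.horizon = horizon
            self.action_dim = action_dim
            self.cache: list[np.ndarray] = []

        def reset(self):
            self.cache.clear()  # drop cached actions so the next call predicts a new chunk

        def select_action(self, obs) -> np.ndarray:
            if not self.cache:
                # Pretend to predict a fresh chunk of `horizon` actions.
                self.cache = list(np.random.randn(self.horizon, self.action_dim))
            return self.cache.pop(0)


    policy = CachingPolicy()
    chunk_usage = 5
    for step_idx in range(12):
        if step_idx % chunk_usage == 0:
            policy.reset()  # forces a new chunk, mirroring the loop above
        action = policy.select_action(obs=None)
        action[-1] = 2.0 * (action[-1] > 0) - 1.0  # gripper command becomes exactly -1 or +1

The parallel script makes the same trade-off explicit by keeping the chunk in a `deque` on the simulation side, whereas the sequential script leaves the caching to the policy and only controls when it is invalidated.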
opentau/scripts/nav_high_level_planner_inference.py
@@ -0,0 +1,61 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ import logging
+ import os
+
+ from dotenv import load_dotenv
+ from PIL import Image
+
+ from opentau.planner import NavHighLevelPlanner
+ from opentau.utils.utils import (
+     init_logging,
+ )
+
+ load_dotenv()
+
+
+ def main(img_dir_path):
+     frames = sorted(os.listdir(img_dir_path))
+     logging.info("Loading the frames")
+     img_dict1 = {}
+     for i, image_path in enumerate(frames):
+         img = Image.open(os.path.join(img_dir_path, image_path)).convert("RGB")
+         img_dict1[i] = img
+
+     # Dummy task instruction
+     task = "The goal is to reach the fridge"
+     nav_planner = NavHighLevelPlanner()
+     logging.info("Running inference with the navigation planner")
+     actions = nav_planner.inference(image_dict=img_dict1, model_name="gpt4o", task=task, mem=None)
+
+     logging.info(f"The instructions are {actions}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Run the navigation high-level planner with a specified image directory."
+     )
+
+     # Add the --img_path argument
+     parser.add_argument(
+         "--img_path", type=str, required=True, help="Path to the directory containing the image frames."
+     )
+
+     # Parse the arguments from the command line
+     args = parser.parse_args()
+
+     init_logging()
+     main(args.img_path)
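One caveat in the frame-loading loop above: `sorted(os.listdir(...))` orders filenames lexicographically, so `frame_10.png` sorts before `frame_2.png` unless the indices are zero-padded. Below is a small sketch of a numeric sort key, assuming the frame filenames embed an integer index; the filename pattern is an assumption for illustration, not taken from the package.

    # Sort frame filenames by an embedded integer index instead of lexicographically.
    # The example filenames are hypothetical.
    import re


    def frame_key(name: str) -> int:
        match = re.search(r"(\d+)", name)
        return int(match.group(1)) if match else -1


    frames = ["frame_10.png", "frame_2.png", "frame_1.png"]
    print(sorted(frames))                 # ['frame_1.png', 'frame_10.png', 'frame_2.png']
    print(sorted(frames, key=frame_key))  # ['frame_1.png', 'frame_2.png', 'frame_10.png']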