opentau 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
opentau/scripts/train.py CHANGED
@@ -17,6 +17,9 @@
  import json
  import logging
  import os
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
  from contextlib import nullcontext
  from pprint import pformat
  from typing import Any
@@ -176,7 +179,10 @@ def train(cfg: TrainPipelineConfig):
      logging.info("Anomaly detection is disabled.")
  
      logging.info("Creating dataset")
-     dataset = make_dataset_mixture(cfg)
+     if cfg.val_freq > 0:
+         train_dataset, val_dataset = make_dataset_mixture(cfg)
+     else:
+         train_dataset = make_dataset_mixture(cfg)
  
      # Create environment used for evaluating checkpoints during training on simulation data.
      # On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -188,7 +194,7 @@ def train(cfg: TrainPipelineConfig):
      )
  
      logging.info("Creating policy")
-     policy = make_policy(cfg=cfg.policy, ds_meta=dataset.meta)
+     policy = make_policy(cfg=cfg.policy, ds_meta=train_dataset.meta)
      policy.to(torch.bfloat16)
      logging.info("Creating optimizer and scheduler")
      optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -203,11 +209,18 @@ def train(cfg: TrainPipelineConfig):
      logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
      logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
  
-     dataloader = dataset.get_dataloader()
-     policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
-         policy, optimizer, dataloader, lr_scheduler
-     )
-     dl_iter = cycle(dataloader)
+     if cfg.val_freq > 0:
+         train_dataloader = train_dataset.get_dataloader()
+         val_dataloader = val_dataset.get_dataloader()
+         policy, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+             policy, optimizer, train_dataloader, val_dataloader, lr_scheduler
+         )
+     else:
+         train_dataloader = train_dataset.get_dataloader()
+         policy, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             policy, optimizer, train_dataloader, lr_scheduler
+         )
+     train_dl_iter = cycle(train_dataloader)
  
      # Register the LR scheduler for checkpointing
      accelerator.register_for_checkpointing(lr_scheduler)
@@ -246,7 +259,7 @@ def train(cfg: TrainPipelineConfig):
          for _ in range(cfg.gradient_accumulation_steps):
              with accelerator.accumulate(policy) if cfg.gradient_accumulation_steps > 1 else nullcontext():
                  logging.debug(f"{step=}, {accelerator.sync_gradients=}")
-                 batch = next(dl_iter)
+                 batch = next(train_dl_iter)
  
                  train_tracker = update_policy(
                      cfg,
@@ -266,20 +279,21 @@ def train(cfg: TrainPipelineConfig):
          is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
          is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and cfg.save_checkpoint
          is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
+         is_val_step = cfg.val_freq > 0 and step % cfg.val_freq == 0
  
          # Only `train_tracker` on the main process keeps useful statistics,
          # because we guarded it with if accelerator.is_main_process in the `update_policy` function.
          if is_log_step and accelerator.is_main_process:
              logging.info(train_tracker)
              log_dict = train_tracker.to_dict(use_avg=True)
-             accelerator.log({"Training Loss": log_dict["loss"]}, step=step)
-             accelerator.log({"MSE Loss": log_dict["mse_loss"]}, step=step)
-             accelerator.log({"CE Loss": log_dict["ce_loss"]}, step=step)
-             accelerator.log({"L1 Loss": log_dict["l1_loss"]}, step=step)
-             accelerator.log({"Accuracy": log_dict["accuracy"]}, step=step)
-             accelerator.log({"Learning Rate": log_dict["lr"]}, step=step)
-             accelerator.log({"Grad Norm": log_dict["grad_norm"]}, step=step)
-             accelerator.log({"Num Samples": log_dict["samples"]}, step=step)
+             accelerator.log({"Training/Loss": log_dict["loss"]}, step=step)
+             accelerator.log({"Training/MSE Loss": log_dict["mse_loss"]}, step=step)
+             accelerator.log({"Training/CE Loss": log_dict["ce_loss"]}, step=step)
+             accelerator.log({"Training/L1 Loss": log_dict["l1_loss"]}, step=step)
+             accelerator.log({"Training/Accuracy": log_dict["accuracy"]}, step=step)
+             accelerator.log({"Training/Learning Rate": log_dict["lr"]}, step=step)
+             accelerator.log({"Training/Grad Norm": log_dict["grad_norm"]}, step=step)
+             accelerator.log({"Training/Num Samples": log_dict["samples"]}, step=step)
              train_tracker.reset_averages()
  
          if is_saving_step:
@@ -299,6 +313,70 @@ def train(cfg: TrainPipelineConfig):
              if cfg.last_checkpoint_only:
                  prune_old_checkpoints(checkpoint_dir)
  
+             accelerator.wait_for_everyone()
+
+         if is_val_step:
+             policy.eval()
+             val_metrics = {
+                 "loss": AverageMeter("val_total_loss", ":.3f"),
+                 "mse_loss": AverageMeter("val_mse_loss", ":.3f"),
+                 "ce_loss": AverageMeter("val_ce_loss", ":.3f"),
+                 "l1_loss": AverageMeter("val_l1_loss", ":.3f"),
+                 "accuracy": AverageMeter("val_accuracy", ":.3f"),
+             }
+             val_tracker = MetricsTracker(
+                 cfg.batch_size * accelerator.num_processes,
+                 val_metrics,
+                 initial_step=step,
+             )
+
+             logging.info(f"Validation at step {step}...")
+
+             with torch.no_grad():
+                 for batch in val_dataloader:
+                     losses = policy.forward(batch)
+                     loss = cfg.loss_weighting["MSE"] * losses["MSE"] + cfg.loss_weighting["CE"] * losses["CE"]
+
+                     # Gather and average metrics across processes
+                     _first_loss_tensor = next(lt for lt in losses.values() if isinstance(lt, torch.Tensor))
+                     zero = torch.tensor(0.0, device=_first_loss_tensor.device, dtype=_first_loss_tensor.dtype)
+
+                     loss = accelerator.gather_for_metrics(loss).mean().item()
+                     mse_loss = (
+                         accelerator.gather_for_metrics(losses["MSE"]).to(dtype=torch.float32).mean().item()
+                     )
+                     ce_loss = (
+                         accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item()
+                     )
+                     l1_loss = (
+                         accelerator.gather_for_metrics(losses.get("L1", zero))
+                         .to(dtype=torch.float32)
+                         .mean()
+                         .item()
+                     )
+                     accuracy = (
+                         accelerator.gather_for_metrics(losses.get("Accuracy", zero))
+                         .to(dtype=torch.float32)
+                         .mean()
+                         .item()
+                     )
+
+                     if accelerator.is_main_process:
+                         val_tracker.loss = loss
+                         val_tracker.mse_loss = mse_loss
+                         val_tracker.ce_loss = ce_loss
+                         val_tracker.l1_loss = l1_loss
+                         val_tracker.accuracy = accuracy
+
+             if accelerator.is_main_process:
+                 logging.info(val_tracker)
+                 val_dict = val_tracker.to_dict(use_avg=True)
+                 accelerator.log({"Validation/Loss": val_dict["loss"]}, step=step)
+                 accelerator.log({"Validation/MSE Loss": val_dict["mse_loss"]}, step=step)
+                 accelerator.log({"Validation/CE Loss": val_dict["ce_loss"]}, step=step)
+                 accelerator.log({"Validation/L1 Loss": val_dict["l1_loss"]}, step=step)
+                 accelerator.log({"Validation/Accuracy": val_dict["accuracy"]}, step=step)
+
          # This barrier is probably necessary to ensure
          # other processes wait for the main process to finish saving
          accelerator.wait_for_everyone()
@@ -357,7 +435,6 @@ def train(cfg: TrainPipelineConfig):
              with open(videos_dir / "eval_info.json", "w") as f:
                  json.dump(eval_info, f, indent=2)
  
-         if is_eval_step:
              # This barrier is to ensure all processes finishes evaluation before the next training step
              # Some processes might be slower than others
              accelerator.wait_for_everyone()
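
The validation hook added above follows the standard Accelerate recipe: every `cfg.val_freq` optimizer steps, run the policy under `torch.no_grad()` over the prepared validation dataloader and combine the per-process losses with `gather_for_metrics`. A minimal, self-contained sketch of that recipe with a generic model and synthetic data (none of OpenTau's actual classes):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 1)
    val_set = torch.utils.data.TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=8)
    model, val_loader = accelerator.prepare(model, val_loader)

    val_freq, step = 500, 500  # pretend we just hit a validation step
    if val_freq > 0 and step % val_freq == 0:
        model.eval()
        total, batches = 0.0, 0
        with torch.no_grad():
            for x, y in val_loader:
                loss = torch.nn.functional.mse_loss(model(x), y)
                # gather the per-process losses, then average them
                total += accelerator.gather_for_metrics(loss).mean().item()
                batches += 1
        if accelerator.is_main_process:
            print(f"validation loss at step {step}: {total / max(batches, 1):.4f}")
        model.train()  # back to training mode for the next step
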
opentau/scripts/visualize_dataset.py CHANGED
@@ -52,7 +52,9 @@ distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode di
  import argparse
  import gc
  import logging
+ import os
  import time
+ import warnings
  from pathlib import Path
  from typing import Iterator
  
@@ -66,6 +68,52 @@ from opentau.configs.default import DatasetMixtureConfig, WandBConfig
  from opentau.configs.train import TrainPipelineConfig
  from opentau.datasets.lerobot_dataset import LeRobotDataset
  
+ PERMIT_URDF = hasattr(rr, "urdf")
+ if not PERMIT_URDF:
+     warnings.warn(
+         "`rerun.urdf` module not found. Make sure you have rerun >= 0.28.2 installed. "
+         "One way to ensure this is to install OpenTau with the '[urdf]' extra: `pip install opentau[urdf]`.",
+         stacklevel=2,
+     )
+
+
+ # Older and newer versions of rerun have different APIs for setting time / sequence
+ def _rr_set_sequence(timeline: str, value: int):
+     if hasattr(rr, "set_time_sequence"):
+         rr.set_time_sequence(timeline, value)
+     else:
+         rr.set_time(timeline, sequence=value)
+
+
+ def _rr_set_seconds(timeline: str, value: float):
+     if hasattr(rr, "set_time_seconds"):
+         rr.set_time_seconds(timeline, value)
+     else:
+         rr.set_time(timeline, timestamp=value)
+
+
+ def _rr_scalar(value: float):
+     """Return a rerun scalar archetype that works across rerun versions.
+
+     Older rerun versions expose `rr.Scalar`, while newer versions expose `rr.Scalars`.
+     This wrapper returns an object suitable for `rr.log(path, ...)` for a single value.
+     """
+     v = float(value)
+
+     # New API (plural archetype)
+     if hasattr(rr, "Scalars"):
+         try:
+             return rr.Scalars(v)
+         except TypeError:
+             # Some versions expect a sequence/array for Scalars.
+             return rr.Scalars([v])
+
+     # Old API
+     if hasattr(rr, "Scalar"):
+         return rr.Scalar(v)
+
+     raise AttributeError("rerun has neither `Scalar` nor `Scalars` - please upgrade `rerun-sdk`.")
+
  
  def create_mock_train_config() -> TrainPipelineConfig:
      """Create a mock TrainPipelineConfig for dataset visualization.
@@ -123,6 +171,7 @@ def visualize_dataset(
      web_port: int = 9090,
      save: bool = False,
      output_dir: Path | None = None,
+     urdf: Path | None = None,
  ) -> Path | None:
      if save:
          assert output_dir is not None, (
@@ -153,6 +202,17 @@
      # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
      gc.collect()
  
+     if urdf:
+         rr.log_file_from_path(urdf, static=True)
+         urdf_tree = rr.urdf.UrdfTree.from_file_path(urdf)
+         urdf_joints = [jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"]
+         print(
+             "Assuming the dataset state dimensions correspond to URDF joints in order:\n",
+             "\n".join(f"{i:3d}: {jnt.name}" for i, jnt in enumerate(urdf_joints)),
+         )
+     else:
+         urdf_joints = []
+
      if mode == "distant":
          rr.serve_web_viewer(open_browser=False, web_port=web_port)
  
@@ -161,8 +221,8 @@
      for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
          # iterate over the batch
          for i in range(len(batch["index"])):
-             rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
-             rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+             _rr_set_sequence("frame_index", batch["frame_index"][i].item())
+             _rr_set_seconds("timestamp", batch["timestamp"][i].item())
  
              # display each camera image
              for key in dataset.meta.camera_keys:
@@ -172,21 +232,27 @@
              # display each dimension of action space (e.g. actuators command)
              if "action" in batch:
                  for dim_idx, val in enumerate(batch["action"][i]):
-                     rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
+                     rr.log(f"action/{dim_idx}", _rr_scalar(val.item()))
  
              # display each dimension of observed state space (e.g. agent position in joint space)
              if "observation.state" in batch:
                  for dim_idx, val in enumerate(batch["observation.state"][i]):
-                     rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
+                     rr.log(f"state/{dim_idx}", _rr_scalar(val.item()))
+                     # Assuming the state dimensions correspond to URDF joints in order.
+                     # TODO(shuheng): allow overriding with a mapping from state dim to joint name.
+                     if dim_idx < len(urdf_joints):
+                         joint = urdf_joints[dim_idx]
+                         transform = joint.compute_transform(float(val))
+                         rr.log("URDF", transform)
  
              if "next.done" in batch:
-                 rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
+                 rr.log("next.done", _rr_scalar(batch["next.done"][i].item()))
  
              if "next.reward" in batch:
-                 rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+                 rr.log("next.reward", _rr_scalar(batch["next.reward"][i].item()))
  
              if "next.success" in batch:
-                 rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
+                 rr.log("next.success", _rr_scalar(batch["next.success"][i].item()))
  
      if mode == "local" and save:
          # save .rrd locally
@@ -272,7 +338,6 @@ def parse_args() -> dict:
              "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
          ),
      )
-
      parser.add_argument(
          "--tolerance-s",
          type=float,
@@ -283,6 +348,22 @@
              "If not given, defaults to 1e-4."
          ),
      )
+     parser.add_argument(
+         "--urdf",
+         type=Path,
+         default=None,
+         help="Path to a URDF file to load and visualize alongside the dataset.",
+     )
+     parser.add_argument(
+         "--urdf-package-dir",
+         type=Path,
+         default=None,
+         help=(
+             "Root directory of the URDF package to resolve package:// paths. "
+             "You can also set the ROS_PACKAGE_PATH environment variable, "
+             "which will be used if this argument is not provided."
+         ),
+     )
  
      args = parser.parse_args()
      return vars(args)
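
Put together, a URDF-enabled run might look like this; the command name and the first three flags come from the script's own docstring, and the paths are placeholders:

    local$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode local \
        --urdf path/to/robot.urdf --urdf-package-dir path/to/package_root
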
@@ -293,6 +374,12 @@ def main():
      repo_id = kwargs.pop("repo_id")
      root = kwargs.pop("root")
      tolerance_s = kwargs.pop("tolerance_s")
+     urdf_package_dir = kwargs.pop("urdf_package_dir")
+     if urdf_package_dir:
+         os.environ["ROS_PACKAGE_PATH"] = urdf_package_dir.resolve().as_posix()
+
+     if not PERMIT_URDF:
+         kwargs["urdf"] = None
  
      logging.info("Loading dataset")
      dataset = LeRobotDataset(
opentau/utils/transformers_patch.py CHANGED
@@ -167,7 +167,10 @@ class PatchedGemmaRMSNorm(nn.Module):
  
      def extra_repr(self) -> str:
          """Returns the extra representation of the module."""
-         repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
+         if hasattr(self, "weight") and self.weight is not None:
+             repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
+         else:
+             repr_str = f"dim={self.dim}, eps={self.eps}"
          if self.dense is not None:
              repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
          return repr_str
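
The new guard matters because `extra_repr` feeds `nn.Module.__repr__`, which can run before the norm's `weight` parameter exists (for instance while a model is still being materialized or patched). A small illustration of the mechanism, not taken from OpenTau:

    import torch.nn as nn

    class Demo(nn.Module):
        def extra_repr(self) -> str:
            # This string is rendered inside the parentheses of repr(Demo())
            return "dim=16, eps=1e-06"

    print(Demo())  # prints: Demo(dim=16, eps=1e-06)
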
opentau-0.1.2.dist-info/METADATA → opentau-0.2.0.dist-info/METADATA RENAMED
@@ -1,8 +1,8 @@
  Metadata-Version: 2.4
  Name: opentau
- Version: 0.1.2
+ Version: 0.2.0
  Summary: OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch
- Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>, Xingrui Gu <xingrui_gu@berkeley.edu>
+ Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>
  License: Apache-2.0
  Project-URL: homepage, https://github.com/TensorAuto/OpenTau
  Project-URL: issues, https://github.com/TensorAuto/OpenTau/issues
@@ -41,9 +41,9 @@ Requires-Dist: pynput>=1.7.7
  Requires-Dist: pyzmq>=26.2.1
  Requires-Dist: rerun-sdk>=0.21.0
  Requires-Dist: termcolor>=2.4.0
- Requires-Dist: torch<2.8.0,>=2.7.1
+ Requires-Dist: torch>=2.7.1
  Requires-Dist: torchcodec<0.5.0,>=0.4.0; sys_platform != "win32" and (sys_platform != "linux" or (platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l")) and (sys_platform != "darwin" or platform_machine != "x86_64")
- Requires-Dist: torchvision<0.23.0,>=0.22.1
+ Requires-Dist: torchvision>=0.22.1
  Requires-Dist: wandb>=0.16.3
  Requires-Dist: zarr>=2.17.0
  Requires-Dist: scikit-learn>=1.7.1
@@ -62,6 +62,10 @@ Requires-Dist: scikit-image>=0.23.2
  Requires-Dist: pandas>=2.2.2
  Requires-Dist: accelerate>=1.4.0
  Requires-Dist: deepspeed>=0.17.1
+ Requires-Dist: gymnasium[other]>=0.29
+ Requires-Dist: grpcio>=1.60.0
+ Requires-Dist: grpcio-tools>=1.60.0
+ Requires-Dist: protobuf>=4.25.0
  Provides-Extra: dev
  Requires-Dist: pre-commit>=3.7.0; extra == "dev"
  Requires-Dist: debugpy>=1.8.1; extra == "dev"
@@ -93,10 +97,11 @@ Requires-Dist: libero; extra == "libero"
  Requires-Dist: numpy<2; extra == "libero"
  Requires-Dist: gym<0.27,>=0.25; extra == "libero"
  Requires-Dist: pyopengl-accelerate==3.1.7; sys_platform == "linux" and extra == "libero"
- Requires-Dist: gymnasium[other]>=0.29; extra == "libero"
  Requires-Dist: mujoco>=3.1.6; sys_platform == "linux" and extra == "libero"
  Requires-Dist: pyopengl==3.1.7; sys_platform == "linux" and extra == "libero"
  Requires-Dist: numpy==1.26.4; sys_platform == "linux" and extra == "libero"
+ Provides-Extra: urdf
+ Requires-Dist: rerun-sdk>=0.28.2; extra == "urdf"
  Dynamic: license-file
  
  <p align="center">
@@ -105,6 +110,19 @@ Dynamic: license-file
  </a>
  </p>
  
+ <p align="center">
+ <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/cpu_test.yml?query=branch%3Amain"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/cpu_test.yml/badge.svg?branch=main" alt="CPU Tests"></a>
+ <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/gpu_test.yml"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/gpu_test.yml/badge.svg" alt="Nightly GPU Tests"></a>
+ <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/regression_test.yml"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/regression_test.yml/badge.svg" alt="Nightly Regression Tests"></a>
+ <a href="https://opentau.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/opentau/badge/?version=latest" alt="Documentation"></a>
+ <a href="https://pypi.org/project/opentau/"><img src="https://img.shields.io/pypi/v/opentau" alt="Version"></a>
+ <a href="https://pypi.org/project/opentau/"><img src="https://img.shields.io/pypi/status/opentau" alt="Status"></a>
+ <a href="https://www.python.org/downloads/"><img src="https://img.shields.io/pypi/pyversions/opentau" alt="Python versions"></a>
+ <a href="https://github.com/TensorAuto/OpenTau/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License"></a>
+ <a href="https://hub.docker.com/r/tensorauto/opentau"><img src="https://img.shields.io/docker/v/tensorauto/opentau?label=Docker" alt="Docker"></a>
+ <a href="https://github.com/pre-commit/pre-commit"><img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit" alt="pre-commit"></a>
+ </p>
+
  # OpenTau - Train VLA models with state-of-the-art techniques by Tensor
  
  At Tensor, we are pushing the frontier of large foundation models for physical AI. In robot learning, a vision-language-action (VLA) model is a multimodal foundation model that integrates vision, language, and action. Today, VLA represents the leading approach for embodied AI, spanning autonomous driving, robot manipulation, and navigation.
@@ -122,17 +140,19 @@ Whether you use the official OpenPi codebase or LeRobot’s reimplementation, yo
  
  OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we also use it internally to train our proprietary in-house models. Our goal is to help you train VLAs on any dataset while fully leveraging state-of-the-art techniques. We plan to continuously upgrade this repository to keep pace with the state of the art in the robotics community.
  
- | Features | OpenPi | LeRobot | **OpenTau** |
- | -------------------------------------------------------: | :---------------------: | :------------------------------: | :---------: |
- | Co-training with Heterogeneous Datasets | | | |
- | Discrete Actions Training in $\pi_{0.5}$ | | | |
- | Knowledge Insulation (KI) between VLM and Action Decoder | | | |
- | Dropout Layers in PaliGemma | ✅ (Jax) <br>❌ (PyTorch) | | |
- | Multi-Node and Multi-GPU Training | | | |
- | Fully Functioning $\pi_{0.5}$ Checkpoint | | ❌ <br> (Missing Text Embeddings) | |
- | Simulation Environments for Evaluating Models | | | ✅ |
- | $\pi^{*}_{0.6}$ style Reinforcement Learning Pipeline | || |
- | Framework | Jax / PyTorch | PyTorch | PyTorch |
+ | Features | OpenPi | LeRobot | **OpenTau** |
+ |---------------------------------------------------------:|:-----------------------:|:--------------------------------:|:-----------:|
+ | Co-training with Heterogeneous Datasets | | | |
+ | Discrete Actions Training in $\pi_{0.5}$ | | | |
+ | Knowledge Insulation (KI) between VLM and Action Decoder | | | |
+ | Dropout Layers in PaliGemma | ✅ (Jax) <br>❌ (PyTorch) | | |
+ | Multi-Node and Multi-GPU Training | | | |
+ | Fully Functioning $\pi_{0.5}$ Checkpoint | | ❌ <br> (Missing Text Embeddings) | |
+ | Visualize dataset with URDF models | | | |
+ | Simulation Environments for Evaluating Models | || |
+ | Create Validation Splits During Training || ❌ | ✅ |
+ | $\pi^{*}_{0.6}$ style Reinforcement Learning Pipeline | ❌ | ❌ | ✅ |
+ | Framework | Jax / PyTorch | PyTorch | PyTorch |
  
  ## Quick Start
  If you are familiar with LeRobot, getting started with OpenTau is very easy.
opentau-0.1.2.dist-info/RECORD → opentau-0.2.0.dist-info/RECORD RENAMED
@@ -2,26 +2,27 @@ opentau/__init__.py,sha256=fIqOYsZsF-bitUI-4taSNke_D1YJYCehGseZNe29GG0,6756
  opentau/__version__.py,sha256=junxoss59Jz_hmg3YnzhpVk_Q5Fo6uha23P1ET81N1c,889
  opentau/constants.py,sha256=-_CbJujCp6hbBjJHgYMguCTcSAkVkmdpM4wHqZp7vRQ,2020
  opentau/configs/__init__.py,sha256=hC-KkeCfq1mtMw9WjPCZfOTxrzQWW7hAa1w8BRC_Bqw,784
- opentau/configs/default.py,sha256=MEoZyzK8olVXVHDy3FHQKK43-ULq2UQv3almymbmdCI,14387
+ opentau/configs/default.py,sha256=T5QS84VNcoFaek8Ry7BsDAvsNsV5RLu2L2y2VvO3THo,15273
+ opentau/configs/deployment.py,sha256=kem4WCkAKV6H_JUFEVyzw4h-EhQTPa6-9pY65Wr9wDQ,3106
  opentau/configs/libero.py,sha256=CrRfiCBYOw7hVqv6orH_ahNyQudj6iyqHtZM9YpdvzE,4688
  opentau/configs/parser.py,sha256=Pb7sw6yx38F31Aqw1J7wK6BRzfBA7DutywShyP_t9bY,14890
  opentau/configs/policies.py,sha256=06oUJx0B4V6krRwyjH1goTYM3RIpRozg4SSwcVJurG4,11667
  opentau/configs/reward.py,sha256=t7S8_RpEy31fAP4v_ygB-ETvaUR6OyrXmS1JNSz3cOk,1537
- opentau/configs/train.py,sha256=a-c-s2zpCUkK8n0DqBPGGaJ87UcVNF02RcLa1pH843Y,18049
+ opentau/configs/train.py,sha256=nn2QX151wI-R-qbghyMkSv1miSPvNSUtsX2Swu-HVGU,18396
  opentau/configs/types.py,sha256=DvKasR2v0ecSmozL0YD4S-64OeuDYhVBhtspxUDV5u0,2453
  opentau/datasets/__init__.py,sha256=oLfV9vfFOg7o2XIRFiN5zOf529FafIkPwqFG7iUX4gc,4248
  opentau/datasets/backward_compatibility.py,sha256=ENVQk9QDPCip6NfAxNF6Vo4VvyCWb0acV-ZxcJBsB6o,3459
  opentau/datasets/compute_stats.py,sha256=N359TDuJicLKMtxxy0JVEcUtnTOB57gL5G8e9Dq0gMQ,13069
  opentau/datasets/dataset_mixture.py,sha256=8UWjY9oKn9jEMe-e9Dy6no1p_21H0kXKv8A10Ku_8_o,19850
- opentau/datasets/factory.py,sha256=NKWpbuNBve0PsmK1midj8g1IpQapeHn-VrxCOC3X4eI,10480
+ opentau/datasets/factory.py,sha256=KVN8XEjeIdfTMohyftghG3dsM50UPjv05lL3eS5aTI4,12116
  opentau/datasets/image_writer.py,sha256=JYCkImHFYpLuE88t16cYqXqQS7EHS7g6kLWXPCJmWgw,11072
- opentau/datasets/lerobot_dataset.py,sha256=c6bGOz75yEJfYkYqlcfszGkap0VBAMBFXrH8fz1P1WQ,84651
+ opentau/datasets/lerobot_dataset.py,sha256=f15Sy3jWOnuPiXiqB8pGdHqv3MOBZgToJyjcA7ry0JU,84778
  opentau/datasets/online_buffer.py,sha256=x14P8tBz25s-hRlE8loFJs5CAvh65RGWeogF271hiF0,19671
  opentau/datasets/sampler.py,sha256=5g-6prsWItVjqkt1J7mA9JPNQPhDSFx3r6rA4JphP9U,4012
  opentau/datasets/standard_data_format_mapping.py,sha256=wEKilksMUjJGeIhvyLuR9qhyhtiJMK1e1AzCkbyx-l4,9667
  opentau/datasets/transforms.py,sha256=pr_8vOEDUoWu7aOUdnI0_wgetsFuie3I2UYFrcStG1k,12976
  opentau/datasets/utils.py,sha256=bZ0Q8KPZMWe9fLdrqJqslgDxI9sa8uxqPQTxEyWwDKw,45062
- opentau/datasets/video_utils.py,sha256=NY20Et6SKWLdG4EjTNdXhpPqWEFON-UccIn_P2YukSQ,21810
+ opentau/datasets/video_utils.py,sha256=AUNUKr4IrDVetfYjzZj9Uq4GKrdHrzVYTcMoF1Jlggw,21968
  opentau/datasets/grounding/__init__.py,sha256=ojvdscCIjp5EQxptFAjPgvjKGZa_Xk9LLZ2wNUebWFw,3139
  opentau/datasets/grounding/base.py,sha256=FDAn2QPQHNBB7mzD7IQ2Bz882Dt1RPasBTgskXqKbP4,5773
  opentau/datasets/grounding/clevr.py,sha256=lNZ0hr5EhQKTh-eg35zujybcAo5p8JQEn9dW8DJhOjI,3983
@@ -59,9 +60,9 @@ opentau/policies/pi0/configuration_pi0.py,sha256=94EG2QlraDsPjD0zyuGwKPqToqV_ayP
  opentau/policies/pi0/modeling_pi0.py,sha256=rz1S7hDOVEv12sN0ECGupddKwQVXMqdvm7G_OooWZLA,37442
  opentau/policies/pi0/paligemma_with_expert.py,sha256=j9P6SL7MVP11MgyJqNnsrZAlOctWmqwDaJJAx4z9F84,20724
  opentau/policies/pi05/__init__.py,sha256=VcIjZwlRW1JChRHqAK9Vz4JAIEP40RrP-W-UdyR6xk4,821
- opentau/policies/pi05/configuration_pi05.py,sha256=ucgCC3BaIC6rcnotMYbElTq7ymPZh4xDGghxCseK33M,9307
- opentau/policies/pi05/modeling_pi05.py,sha256=sF4OPGQWQ0eJevGDRZDPUItBkBC5Wp6WdpwsN4uqS14,50639
- opentau/policies/pi05/paligemma_with_expert.py,sha256=nxBfUwBt6S4WwPDkn8LW3lHFcOib4CtoZzJEq4VKlck,23328
+ opentau/policies/pi05/configuration_pi05.py,sha256=GmENtmgvI5q3gQQlZnH6RalV5msU5gwtTK_imbNH6a8,9367
+ opentau/policies/pi05/modeling_pi05.py,sha256=sBmnP2cqaVHvCGGpg624LYbun52hZs63UCeadiE4dH4,63327
+ opentau/policies/pi05/paligemma_with_expert.py,sha256=jyYkcOVMxMNYJtmbR3qZPmEoinOY1myqnq1JKmledkc,23407
  opentau/policies/value/__init__.py,sha256=wUP5vdpsvfnREqt4enfwaakzJ-ynX9sLYN14t2zEtpA,772
  opentau/policies/value/configuration_value.py,sha256=ApjrNKHxvjNlSZ71-BPvanNAJh9GzAK9W0JCiz3mMHs,5951
  opentau/policies/value/modeling_value.py,sha256=21x2EVGFlJbLasJDGTs3b5YMrrChrkuGxrYP7wLjkCY,18594
@@ -78,13 +79,16 @@ opentau/scripts/fake_tensor_training.py,sha256=y4F3CFs2jjpIJcT1wKvsrgFEebU9QFzba
  opentau/scripts/get_advantage_and_percentiles.py,sha256=JdjlADYzdS1Jc_19H6lLYMRnPlWxeckRSUQqwqb0rC4,8993
  opentau/scripts/high_level_planner_inference.py,sha256=nbXr8Hp64YGeprMTpT8kvT_NgpBlI02CUlO6Mm2Js_E,3846
  opentau/scripts/inference.py,sha256=_lp9YjPzarAnjiA8k2jBlIKZxza6PEHw--UyaqLPdNo,2110
- opentau/scripts/launch.py,sha256=kcJtdO1WHYxiHSJpJ_y618tbIvBuGXy8FmH5BEEdVdI,2826
- opentau/scripts/libero_simulation_parallel.py,sha256=qMee6T0EwMoAT1J2u8X4w8rsbOJYwyqD3LRAPe2Ta1g,13105
- opentau/scripts/libero_simulation_sequential.py,sha256=xFSUQEuyai20QD-pYitp-UJPGE9zlaaIu4YSO0bhYKg,4775
+ opentau/scripts/launch.py,sha256=L_KlkcJpcOsSMGlBKSmtTUyzb7q8tH4FkmuHx8lEdDI,2845
  opentau/scripts/nav_high_level_planner_inference.py,sha256=z2WHw68NWi-fJUd5TV4CrJHzxo-L7e2UliGjfOlqifM,1878
- opentau/scripts/train.py,sha256=nkvsvna5yliphp7pwVyFY6yBwCA_kmffyohRO2wjiHU,16850
- opentau/scripts/visualize_dataset.py,sha256=RsON_13oqTm7HN14tGnDBIVAJPCW_-EJzpMHeiXxp24,10492
+ opentau/scripts/train.py,sha256=dp_366gKpFeIcv2tfDkuuFGCdPyP74lkENETZpwR5m4,20547
+ opentau/scripts/visualize_dataset.py,sha256=ZfB7Qbsl3RGqu8k7n6CK6iRbhuYhYSX167tta2b01NQ,13625
  opentau/scripts/zero_to_fp32.py,sha256=Rkl1ZczytKix9vGMg0EELzdJYFqUM1yB9p3xvSaK9k8,33272
+ opentau/scripts/grpc/__init__.py,sha256=wBiZyRqF1PCpZHgqwHjSZaaeFRHLOX4ZggCXbAzngOs,799
+ opentau/scripts/grpc/client.py,sha256=PbuAb14izNAItspdvppItpv6gaycBs2_mdcvrdtAXnQ,20181
+ opentau/scripts/grpc/robot_inference_pb2.py,sha256=Se7elLBIHeoi3UIBXrn02w8rZVOg-gpBcuirt23d8Tg,4125
+ opentau/scripts/grpc/robot_inference_pb2_grpc.py,sha256=LymJbsaei6C-pp9ffz2ETCVIhq4SRh--LdK8ne0-yug,7659
+ opentau/scripts/grpc/server.py,sha256=x6BA0F0uYIXeBw4mnOXaYb2-y8zEbgMXoiRdtNvyq1g,11146
  opentau/utils/__init__.py,sha256=hIUeGPpZHf2AVf0-5C2p0BOcY0cFHCTT5yHn-SpEPwY,856
  opentau/utils/accelerate_utils.py,sha256=vXnSGo1hXCUNof-oNKLMJ_SOMjpKhpZ1gx21ObSsopI,2630
  opentau/utils/benchmark.py,sha256=jVli6gdBRMXAqNM3AIi43a0N_O1CLQMbKXsPK_e2y3s,3063
@@ -98,11 +102,11 @@ opentau/utils/logging_utils.py,sha256=zd7ypmk7aqVposPhA7Kg-PYrstapY4MsuTklsTD4r4
  opentau/utils/monkey_patch.py,sha256=cVgZ1N-NNVnlRKPA1dwO9FM4IbxR0V_Hbil6p-6knhA,9558
  opentau/utils/random_utils.py,sha256=k3Ab3Y98LozGdsBzKoP8xSsFTcnaRqUzY34BsETCrrA,9102
  opentau/utils/train_utils.py,sha256=0d7yvk8wlP-75pwB55gr095b_b1sWG5nlqdVxyH6_o0,6796
- opentau/utils/transformers_patch.py,sha256=-3Fvf-_owtT_QDUkoGfMWO-pxN5xeQikPljtLMn4MRs,9906
+ opentau/utils/transformers_patch.py,sha256=rPG2Yn7GQr2gCEykhW42uOoKP_jdAMx4p3q-IUcGYDI,10045
  opentau/utils/utils.py,sha256=DrMStfjBEkw_8WVhYMnCQJNBxMeozIJ8LBSpOtMQhFM,15760
- opentau-0.1.2.dist-info/licenses/LICENSE,sha256=tl3_NkxplsgU86xSvEWnDlE1UR_JsIvGo7t4hPtsIbE,27680
- opentau-0.1.2.dist-info/METADATA,sha256=Up5VRGhf8RVjBA0mBy6xKA21-6R_t51xvGmG-YgC1EQ,10943
- opentau-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- opentau-0.1.2.dist-info/entry_points.txt,sha256=NGF_MWpSKri0lvjR9WGN4pBUap8B-z21f7XMluxc1M4,208
- opentau-0.1.2.dist-info/top_level.txt,sha256=7_yrS4x5KSeTRr2LICTCNOZmF-1_kSOFPKHvtJPL1Dw,8
- opentau-0.1.2.dist-info/RECORD,,
+ opentau-0.2.0.dist-info/licenses/LICENSE,sha256=tl3_NkxplsgU86xSvEWnDlE1UR_JsIvGo7t4hPtsIbE,27680
+ opentau-0.2.0.dist-info/METADATA,sha256=OUN-KCCeTA1cFdlzmwRKecxepcmjpie6-AbbIkoELhI,12992
+ opentau-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ opentau-0.2.0.dist-info/entry_points.txt,sha256=NGF_MWpSKri0lvjR9WGN4pBUap8B-z21f7XMluxc1M4,208
+ opentau-0.2.0.dist-info/top_level.txt,sha256=7_yrS4x5KSeTRr2LICTCNOZmF-1_kSOFPKHvtJPL1Dw,8
+ opentau-0.2.0.dist-info/RECORD,,
opentau-0.1.2.dist-info/WHEEL → opentau-0.2.0.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any