opentau 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/configs/default.py +16 -0
- opentau/configs/deployment.py +85 -0
- opentau/configs/train.py +5 -0
- opentau/datasets/factory.py +43 -10
- opentau/datasets/lerobot_dataset.py +12 -10
- opentau/datasets/video_utils.py +11 -6
- opentau/policies/pi05/configuration_pi05.py +9 -6
- opentau/policies/pi05/modeling_pi05.py +296 -30
- opentau/policies/pi05/paligemma_with_expert.py +20 -20
- opentau/scripts/grpc/__init__.py +19 -0
- opentau/scripts/grpc/client.py +601 -0
- opentau/scripts/grpc/robot_inference_pb2.py +61 -0
- opentau/scripts/grpc/robot_inference_pb2_grpc.py +210 -0
- opentau/scripts/grpc/server.py +313 -0
- opentau/scripts/launch.py +8 -5
- opentau/scripts/train.py +94 -17
- opentau/scripts/visualize_dataset.py +95 -8
- opentau/utils/transformers_patch.py +4 -1
- {opentau-0.1.2.dist-info → opentau-0.2.0.dist-info}/METADATA +36 -16
- {opentau-0.1.2.dist-info → opentau-0.2.0.dist-info}/RECORD +24 -20
- {opentau-0.1.2.dist-info → opentau-0.2.0.dist-info}/WHEEL +1 -1
- opentau/scripts/libero_simulation_parallel.py +0 -356
- opentau/scripts/libero_simulation_sequential.py +0 -122
- {opentau-0.1.2.dist-info → opentau-0.2.0.dist-info}/entry_points.txt +0 -0
- {opentau-0.1.2.dist-info → opentau-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {opentau-0.1.2.dist-info → opentau-0.2.0.dist-info}/top_level.txt +0 -0
opentau/scripts/train.py
CHANGED
```diff
@@ -17,6 +17,9 @@
 import json
 import logging
 import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 from contextlib import nullcontext
 from pprint import pformat
 from typing import Any
```
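Setting `TOKENIZERS_PARALLELISM=false` at import time, before any DataLoader workers are forked, suppresses the Hugging Face tokenizers fork warning and the potential deadlock it guards against.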
```diff
@@ -176,7 +179,10 @@ def train(cfg: TrainPipelineConfig):
     logging.info("Anomaly detection is disabled.")
 
     logging.info("Creating dataset")
-
+    if cfg.val_freq > 0:
+        train_dataset, val_dataset = make_dataset_mixture(cfg)
+    else:
+        train_dataset = make_dataset_mixture(cfg)
 
     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -188,7 +194,7 @@ def train(cfg: TrainPipelineConfig):
     )
 
     logging.info("Creating policy")
-    policy = make_policy(cfg=cfg.policy, ds_meta=
+    policy = make_policy(cfg=cfg.policy, ds_meta=train_dataset.meta)
     policy.to(torch.bfloat16)
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
```
```diff
@@ -203,11 +209,18 @@
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
-
-
-
-
-
+    if cfg.val_freq > 0:
+        train_dataloader = train_dataset.get_dataloader()
+        val_dataloader = val_dataset.get_dataloader()
+        policy, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, train_dataloader, val_dataloader, lr_scheduler
+        )
+    else:
+        train_dataloader = train_dataset.get_dataloader()
+        policy, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, train_dataloader, lr_scheduler
+        )
+    train_dl_iter = cycle(train_dataloader)
 
     # Register the LR scheduler for checkpointing
     accelerator.register_for_checkpointing(lr_scheduler)
```
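The new `train_dl_iter = cycle(train_dataloader)` turns the finite dataloader into an endless iterator, so the step-based loop below can call `next()` without tracking epochs. A minimal sketch of such a helper (OpenTau's own `cycle` may differ in details):

```python
from typing import Iterable, Iterator


def cycle(iterable: Iterable) -> Iterator:
    """Yield from `iterable` forever, restarting it when exhausted.

    Unlike itertools.cycle, this calls iter() again on each pass, so a
    shuffling DataLoader re-shuffles at every epoch boundary instead of
    replaying cached items.
    """
    while True:
        yield from iterable
```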
```diff
@@ -246,7 +259,7 @@ def train(cfg: TrainPipelineConfig):
         for _ in range(cfg.gradient_accumulation_steps):
             with accelerator.accumulate(policy) if cfg.gradient_accumulation_steps > 1 else nullcontext():
                 logging.debug(f"{step=}, {accelerator.sync_gradients=}")
-                batch = next(
+                batch = next(train_dl_iter)
 
                 train_tracker = update_policy(
                     cfg,
@@ -266,20 +279,21 @@ def train(cfg: TrainPipelineConfig):
         is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
         is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and cfg.save_checkpoint
         is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
+        is_val_step = cfg.val_freq > 0 and step % cfg.val_freq == 0
 
         # Only `train_tracker` on the main process keeps useful statistics,
         # because we guarded it with if accelerator.is_main_process in the `update_policy` function.
         if is_log_step and accelerator.is_main_process:
             logging.info(train_tracker)
             log_dict = train_tracker.to_dict(use_avg=True)
-            accelerator.log({"Training
-            accelerator.log({"MSE Loss": log_dict["mse_loss"]}, step=step)
-            accelerator.log({"CE Loss": log_dict["ce_loss"]}, step=step)
-            accelerator.log({"L1 Loss": log_dict["l1_loss"]}, step=step)
-            accelerator.log({"Accuracy": log_dict["accuracy"]}, step=step)
-            accelerator.log({"Learning Rate": log_dict["lr"]}, step=step)
-            accelerator.log({"Grad Norm": log_dict["grad_norm"]}, step=step)
-            accelerator.log({"Num Samples": log_dict["samples"]}, step=step)
+            accelerator.log({"Training/Loss": log_dict["loss"]}, step=step)
+            accelerator.log({"Training/MSE Loss": log_dict["mse_loss"]}, step=step)
+            accelerator.log({"Training/CE Loss": log_dict["ce_loss"]}, step=step)
+            accelerator.log({"Training/L1 Loss": log_dict["l1_loss"]}, step=step)
+            accelerator.log({"Training/Accuracy": log_dict["accuracy"]}, step=step)
+            accelerator.log({"Training/Learning Rate": log_dict["lr"]}, step=step)
+            accelerator.log({"Training/Grad Norm": log_dict["grad_norm"]}, step=step)
+            accelerator.log({"Training/Num Samples": log_dict["samples"]}, step=step)
             train_tracker.reset_averages()
 
         if is_saving_step:
```
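Prefixing every key with `Training/` groups the metrics into one section in trackers such as Weights & Biases and TensorBoard, which split chart panels on the `/` separator; it also keeps them visually separate from the `Validation/` metrics introduced below.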
```diff
@@ -299,6 +313,70 @@ def train(cfg: TrainPipelineConfig):
             if cfg.last_checkpoint_only:
                 prune_old_checkpoints(checkpoint_dir)
 
+        accelerator.wait_for_everyone()
+
+        if is_val_step:
+            policy.eval()
+            val_metrics = {
+                "loss": AverageMeter("val_total_loss", ":.3f"),
+                "mse_loss": AverageMeter("val_mse_loss", ":.3f"),
+                "ce_loss": AverageMeter("val_ce_loss", ":.3f"),
+                "l1_loss": AverageMeter("val_l1_loss", ":.3f"),
+                "accuracy": AverageMeter("val_accuracy", ":.3f"),
+            }
+            val_tracker = MetricsTracker(
+                cfg.batch_size * accelerator.num_processes,
+                val_metrics,
+                initial_step=step,
+            )
+
+            logging.info(f"Validation at step {step}...")
+
+            with torch.no_grad():
+                for batch in val_dataloader:
+                    losses = policy.forward(batch)
+                    loss = cfg.loss_weighting["MSE"] * losses["MSE"] + cfg.loss_weighting["CE"] * losses["CE"]
+
+                    # Gather and average metrics across processes
+                    _first_loss_tensor = next(lt for lt in losses.values() if isinstance(lt, torch.Tensor))
+                    zero = torch.tensor(0.0, device=_first_loss_tensor.device, dtype=_first_loss_tensor.dtype)
+
+                    loss = accelerator.gather_for_metrics(loss).mean().item()
+                    mse_loss = (
+                        accelerator.gather_for_metrics(losses["MSE"]).to(dtype=torch.float32).mean().item()
+                    )
+                    ce_loss = (
+                        accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item()
+                    )
+                    l1_loss = (
+                        accelerator.gather_for_metrics(losses.get("L1", zero))
+                        .to(dtype=torch.float32)
+                        .mean()
+                        .item()
+                    )
+                    accuracy = (
+                        accelerator.gather_for_metrics(losses.get("Accuracy", zero))
+                        .to(dtype=torch.float32)
+                        .mean()
+                        .item()
+                    )
+
+                    if accelerator.is_main_process:
+                        val_tracker.loss = loss
+                        val_tracker.mse_loss = mse_loss
+                        val_tracker.ce_loss = ce_loss
+                        val_tracker.l1_loss = l1_loss
+                        val_tracker.accuracy = accuracy
+
+            if accelerator.is_main_process:
+                logging.info(val_tracker)
+                val_dict = val_tracker.to_dict(use_avg=True)
+                accelerator.log({"Validation/Loss": val_dict["loss"]}, step=step)
+                accelerator.log({"Validation/MSE Loss": val_dict["mse_loss"]}, step=step)
+                accelerator.log({"Validation/CE Loss": val_dict["ce_loss"]}, step=step)
+                accelerator.log({"Validation/L1 Loss": val_dict["l1_loss"]}, step=step)
+                accelerator.log({"Validation/Accuracy": val_dict["accuracy"]}, step=step)
+
         # This barrier is probably necessary to ensure
         # other processes wait for the main process to finish saving
         accelerator.wait_for_everyone()
```
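The validation loop collects each per-process loss with Accelerate's `gather_for_metrics`, which concatenates values across processes (dropping samples duplicated to pad the dataloader's last batch) before averaging. A self-contained sketch of that pattern, with illustrative values:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()


def gathered_mean(metric: torch.Tensor) -> float:
    # Collect the metric from every process, upcast (e.g. from bfloat16),
    # and average; every process receives the same result.
    return accelerator.gather_for_metrics(metric).to(torch.float32).mean().item()


# Illustrative per-process scalar, e.g. one batch's MSE loss.
per_process_loss = torch.tensor(0.123, device=accelerator.device)
print(gathered_mean(per_process_loss))
```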
```diff
@@ -357,7 +435,6 @@ def train(cfg: TrainPipelineConfig):
             with open(videos_dir / "eval_info.json", "w") as f:
                 json.dump(eval_info, f, indent=2)
 
-        if is_eval_step:
         # This barrier is to ensure all processes finishes evaluation before the next training step
         # Some processes might be slower than others
         accelerator.wait_for_everyone()
```
opentau/scripts/visualize_dataset.py
CHANGED
```diff
@@ -52,7 +52,9 @@ distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode di
 import argparse
 import gc
 import logging
+import os
 import time
+import warnings
 from pathlib import Path
 from typing import Iterator
 
```
```diff
@@ -66,6 +68,52 @@ from opentau.configs.default import DatasetMixtureConfig, WandBConfig
 from opentau.configs.train import TrainPipelineConfig
 from opentau.datasets.lerobot_dataset import LeRobotDataset
 
+PERMIT_URDF = hasattr(rr, "urdf")
+if not PERMIT_URDF:
+    warnings.warn(
+        "`rerun.urdf` module not found. Make sure you have rerun >= 0.28.2 installed. "
+        " One way to ensure this is to install OpenTau with the '[urdf]' extra: `pip install opentau[urdf]`.",
+        stacklevel=2,
+    )
+
+
+# Older and newer versions of rerun have different APIs for setting time / sequence
+def _rr_set_sequence(timeline: str, value: int):
+    if hasattr(rr, "set_time_sequence"):
+        rr.set_time_sequence(timeline, value)
+    else:
+        rr.set_time(timeline, sequence=value)
+
+
+def _rr_set_seconds(timeline: str, value: float):
+    if hasattr(rr, "set_time_seconds"):
+        rr.set_time_seconds(timeline, value)
+    else:
+        rr.set_time(timeline, timestamp=value)
+
+
+def _rr_scalar(value: float):
+    """Return a rerun scalar archetype that works across rerun versions.
+
+    Older rerun versions expose `rr.Scalar`, while newer versions expose `rr.Scalars`.
+    This wrapper returns an object suitable for `rr.log(path, ...)` for a single value.
+    """
+    v = float(value)
+
+    # New API (plural archetype)
+    if hasattr(rr, "Scalars"):
+        try:
+            return rr.Scalars(v)
+        except TypeError:
+            # Some versions expect a sequence/array for Scalars.
+            return rr.Scalars([v])
+
+    # Old API
+    if hasattr(rr, "Scalar"):
+        return rr.Scalar(v)
+
+    raise AttributeError("rerun has neither `Scalar` nor `Scalars` - please upgrade `rerun-sdk`.")
+
 
 def create_mock_train_config() -> TrainPipelineConfig:
     """Create a mock TrainPipelineConfig for dataset visualization.
```
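Taken together, the shims let the logging code further down stay version-agnostic. A minimal usage sketch, assuming the `_rr_*` helpers from the hunk above are in scope (the recording name is illustrative):

```python
import rerun as rr

rr.init("opentau_viz_demo", spawn=False)  # illustrative app id

# Pin the current sample on the "frame_index" timeline, then log one
# scalar under "state/0"; both calls work on old and new rerun APIs.
_rr_set_sequence("frame_index", 0)
_rr_set_seconds("timestamp", 0.033)
rr.log("state/0", _rr_scalar(0.25))
```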
```diff
@@ -123,6 +171,7 @@ def visualize_dataset(
     web_port: int = 9090,
     save: bool = False,
     output_dir: Path | None = None,
+    urdf: Path | None = None,
 ) -> Path | None:
     if save:
         assert output_dir is not None, (
```
```diff
@@ -153,6 +202,17 @@
     # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
     gc.collect()
 
+    if urdf:
+        rr.log_file_from_path(urdf, static=True)
+        urdf_tree = rr.urdf.UrdfTree.from_file_path(urdf)
+        urdf_joints = [jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"]
+        print(
+            "Assuming the dataset state dimensions correspond to URDF joints in order:\n",
+            "\n".join(f"{i:3d}: {jnt.name}" for i, jnt in enumerate(urdf_joints)),
+        )
+    else:
+        urdf_joints = []
+
     if mode == "distant":
         rr.serve_web_viewer(open_browser=False, web_port=web_port)
 
```
```diff
@@ -161,8 +221,8 @@
     for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
         # iterate over the batch
         for i in range(len(batch["index"])):
-
-
+            _rr_set_sequence("frame_index", batch["frame_index"][i].item())
+            _rr_set_seconds("timestamp", batch["timestamp"][i].item())
 
             # display each camera image
             for key in dataset.meta.camera_keys:
```
```diff
@@ -172,21 +232,27 @@
             # display each dimension of action space (e.g. actuators command)
             if "action" in batch:
                 for dim_idx, val in enumerate(batch["action"][i]):
-                    rr.log(f"action/{dim_idx}",
+                    rr.log(f"action/{dim_idx}", _rr_scalar(val.item()))
 
             # display each dimension of observed state space (e.g. agent position in joint space)
             if "observation.state" in batch:
                 for dim_idx, val in enumerate(batch["observation.state"][i]):
-                    rr.log(f"state/{dim_idx}",
+                    rr.log(f"state/{dim_idx}", _rr_scalar(val.item()))
+                    # Assuming the state dimensions correspond to URDF joints in order.
+                    # TODO(shuheng): allow overriding with a mapping from state dim to joint name.
+                    if dim_idx < len(urdf_joints):
+                        joint = urdf_joints[dim_idx]
+                        transform = joint.compute_transform(float(val))
+                        rr.log("URDF", transform)
 
             if "next.done" in batch:
-                rr.log("next.done",
+                rr.log("next.done", _rr_scalar(batch["next.done"][i].item()))
 
             if "next.reward" in batch:
-                rr.log("next.reward",
+                rr.log("next.reward", _rr_scalar(batch["next.reward"][i].item()))
 
             if "next.success" in batch:
-                rr.log("next.success",
+                rr.log("next.success", _rr_scalar(batch["next.success"][i].item()))
 
     if mode == "local" and save:
         # save .rrd locally
```
```diff
@@ -272,7 +338,6 @@ def parse_args() -> dict:
             "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
         ),
     )
-
     parser.add_argument(
         "--tolerance-s",
         type=float,
```
```diff
@@ -283,6 +348,22 @@
             "If not given, defaults to 1e-4."
         ),
     )
+    parser.add_argument(
+        "--urdf",
+        type=Path,
+        default=None,
+        help="Path to a URDF file to load and visualize alongside the dataset.",
+    )
+    parser.add_argument(
+        "--urdf-package-dir",
+        type=Path,
+        default=None,
+        help=(
+            "Root directory of the URDF package to resolve package:// paths. "
+            "You can also set the ROS_PACKAGE_PATH environment variable, "
+            "which will be used if this argument is not provided."
+        ),
+    )
 
     args = parser.parse_args()
     return vars(args)
```
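With the new flags, a run that also renders the robot model might look like `opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --urdf path/to/robot.urdf --urdf-package-dir path/to/package_root` (the paths here are placeholders). Without the `[urdf]` extra installed, the `--urdf` argument is silently dropped, as `main()` in the next hunk shows.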
```diff
@@ -293,6 +374,12 @@ def main():
     repo_id = kwargs.pop("repo_id")
     root = kwargs.pop("root")
     tolerance_s = kwargs.pop("tolerance_s")
+    urdf_package_dir = kwargs.pop("urdf_package_dir")
+    if urdf_package_dir:
+        os.environ["ROS_PACKAGE_PATH"] = urdf_package_dir.resolve().as_posix()
+
+    if not PERMIT_URDF:
+        kwargs["urdf"] = None
 
     logging.info("Loading dataset")
     dataset = LeRobotDataset(
```
opentau/utils/transformers_patch.py
CHANGED
```diff
@@ -167,7 +167,10 @@ class PatchedGemmaRMSNorm(nn.Module):
 
     def extra_repr(self) -> str:
         """Returns the extra representation of the module."""
-
+        if hasattr(self, "weight") and self.weight is not None:
+            repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
+        else:
+            repr_str = f"dim={self.dim}, eps={self.eps}"
         if self.dense is not None:
             repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
         return repr_str
```
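The guard makes `extra_repr` safe when the norm's `weight` is absent (e.g. has been folded away), falling back to reporting `dim`. A hypothetical stand-in (`ReprDemo` is not part of OpenTau) reproducing the branch behavior:

```python
import torch
import torch.nn as nn


class ReprDemo(nn.Module):
    """Stand-in mirroring the fields PatchedGemmaRMSNorm.extra_repr reads:
    weight, eps, dim, dense, cond_dim."""

    def __init__(self, dim: int = 2048, eps: float = 1e-6, with_weight: bool = True):
        super().__init__()
        self.dim, self.eps = dim, eps
        self.dense, self.cond_dim = None, None
        self.weight = nn.Parameter(torch.ones(dim)) if with_weight else None

    def extra_repr(self) -> str:
        if hasattr(self, "weight") and self.weight is not None:
            repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
        else:
            repr_str = f"dim={self.dim}, eps={self.eps}"
        if self.dense is not None:
            repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
        return repr_str


print(ReprDemo())                   # ReprDemo((2048,), eps=1e-06)
print(ReprDemo(with_weight=False))  # ReprDemo(dim=2048, eps=1e-06)
```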
{opentau-0.1.2.dist-info → opentau-0.2.0.dist-info}/METADATA
CHANGED
```diff
@@ -1,8 +1,8 @@
 Metadata-Version: 2.4
 Name: opentau
-Version: 0.1.2
+Version: 0.2.0
 Summary: OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch
-Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com
+Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>
 License: Apache-2.0
 Project-URL: homepage, https://github.com/TensorAuto/OpenTau
 Project-URL: issues, https://github.com/TensorAuto/OpenTau/issues
```
```diff
@@ -41,9 +41,9 @@ Requires-Dist: pynput>=1.7.7
 Requires-Dist: pyzmq>=26.2.1
 Requires-Dist: rerun-sdk>=0.21.0
 Requires-Dist: termcolor>=2.4.0
-Requires-Dist: torch
+Requires-Dist: torch>=2.7.1
 Requires-Dist: torchcodec<0.5.0,>=0.4.0; sys_platform != "win32" and (sys_platform != "linux" or (platform_machine != "aarch64" and platform_machine != "arm64" and platform_machine != "armv7l")) and (sys_platform != "darwin" or platform_machine != "x86_64")
-Requires-Dist: torchvision
+Requires-Dist: torchvision>=0.22.1
 Requires-Dist: wandb>=0.16.3
 Requires-Dist: zarr>=2.17.0
 Requires-Dist: scikit-learn>=1.7.1
```
```diff
@@ -62,6 +62,10 @@ Requires-Dist: scikit-image>=0.23.2
 Requires-Dist: pandas>=2.2.2
 Requires-Dist: accelerate>=1.4.0
 Requires-Dist: deepspeed>=0.17.1
+Requires-Dist: gymnasium[other]>=0.29
+Requires-Dist: grpcio>=1.60.0
+Requires-Dist: grpcio-tools>=1.60.0
+Requires-Dist: protobuf>=4.25.0
 Provides-Extra: dev
 Requires-Dist: pre-commit>=3.7.0; extra == "dev"
 Requires-Dist: debugpy>=1.8.1; extra == "dev"
```
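The new `grpcio`, `grpcio-tools`, and `protobuf` requirements back the generated stubs added under `opentau/scripts/grpc/`. As a sketch of how such stubs are typically regenerated with `grpc_tools` (the `robot_inference.proto` filename is inferred from the generated module names and may not match the repository layout):

```python
from grpc_tools import protoc

# Regenerate robot_inference_pb2.py and robot_inference_pb2_grpc.py
# next to the .proto file; argv[0] is a dummy program name.
protoc.main(
    [
        "protoc",
        "-I.",
        "--python_out=.",
        "--grpc_python_out=.",
        "robot_inference.proto",
    ]
)
```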
```diff
@@ -93,10 +97,11 @@ Requires-Dist: libero; extra == "libero"
 Requires-Dist: numpy<2; extra == "libero"
 Requires-Dist: gym<0.27,>=0.25; extra == "libero"
 Requires-Dist: pyopengl-accelerate==3.1.7; sys_platform == "linux" and extra == "libero"
-Requires-Dist: gymnasium[other]>=0.29; extra == "libero"
 Requires-Dist: mujoco>=3.1.6; sys_platform == "linux" and extra == "libero"
 Requires-Dist: pyopengl==3.1.7; sys_platform == "linux" and extra == "libero"
 Requires-Dist: numpy==1.26.4; sys_platform == "linux" and extra == "libero"
+Provides-Extra: urdf
+Requires-Dist: rerun-sdk>=0.28.2; extra == "urdf"
 Dynamic: license-file
 
 <p align="center">
```
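The new `urdf` extra pins the newer viewer API: `pip install "opentau[urdf]"` pulls in `rerun-sdk>=0.28.2`, which provides the `rerun.urdf` module that `visualize_dataset.py` probes for at import time.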
```diff
@@ -105,6 +110,19 @@ Dynamic: license-file
   </a>
 </p>
 
+<p align="center">
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/cpu_test.yml?query=branch%3Amain"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/cpu_test.yml/badge.svg?branch=main" alt="CPU Tests"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/gpu_test.yml"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/gpu_test.yml/badge.svg" alt="Nightly GPU Tests"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/actions/workflows/regression_test.yml"><img src="https://github.com/TensorAuto/OpenTau/actions/workflows/regression_test.yml/badge.svg" alt="Nightly Regression Tests"></a>
+  <a href="https://opentau.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/opentau/badge/?version=latest" alt="Documentation"></a>
+  <a href="https://pypi.org/project/opentau/"><img src="https://img.shields.io/pypi/v/opentau" alt="Version"></a>
+  <a href="https://pypi.org/project/opentau/"><img src="https://img.shields.io/pypi/status/opentau" alt="Status"></a>
+  <a href="https://www.python.org/downloads/"><img src="https://img.shields.io/pypi/pyversions/opentau" alt="Python versions"></a>
+  <a href="https://github.com/TensorAuto/OpenTau/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License"></a>
+  <a href="https://hub.docker.com/r/tensorauto/opentau"><img src="https://img.shields.io/docker/v/tensorauto/opentau?label=Docker" alt="Docker"></a>
+  <a href="https://github.com/pre-commit/pre-commit"><img src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit" alt="pre-commit"></a>
+</p>
+
 # OpenTau - Train VLA models with state-of-the-art techniques by Tensor
 
 At Tensor, we are pushing the frontier of large foundation models for physical AI. In robot learning, a vision-language-action (VLA) model is a multimodal foundation model that integrates vision, language, and action. Today, VLA represents the leading approach for embodied AI, spanning autonomous driving, robot manipulation, and navigation.
```
```diff
@@ -122,17 +140,19 @@ Whether you use the official OpenPi codebase or LeRobot’s reimplementation, yo
 
 OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we also use it internally to train our proprietary in-house models. Our goal is to help you train VLAs on any dataset while fully leveraging state-of-the-art techniques. We plan to continuously upgrade this repository to keep pace with the state of the art in the robotics community.
 
-
-
-
-
-| Knowledge Insulation (KI) between VLM and Action Decoder |
-
-
-
-
-
-
+| Features                                                  | OpenPi                   | LeRobot                           | **OpenTau** |
+|----------------------------------------------------------:|:------------------------:|:---------------------------------:|:-----------:|
+| Co-training with Heterogeneous Datasets                   | ❌                       | ❌                                | ✅          |
+| Discrete Actions Training in $\pi_{0.5}$                  | ❌                       | ❌                                | ✅          |
+| Knowledge Insulation (KI) between VLM and Action Decoder  | ❌                       | ❌                                | ✅          |
+| Dropout Layers in PaliGemma                               | ✅ (Jax) <br>❌ (PyTorch) | ❌                                | ✅          |
+| Multi-Node and Multi-GPU Training                         | ❌                       | ✅                                | ✅          |
+| Fully Functioning $\pi_{0.5}$ Checkpoint                  | ✅                       | ❌ <br> (Missing Text Embeddings) | ✅          |
+| Visualize dataset with URDF models                        | ❌                       | ❌                                | ✅          |
+| Simulation Environments for Evaluating Models             | ❌                       | ✅                                | ✅          |
+| Create Validation Splits During Training                  | ❌                       | ❌                                | ✅          |
+| $\pi^{*}_{0.6}$ style Reinforcement Learning Pipeline     | ❌                       | ❌                                | ✅          |
+| Framework                                                 | Jax / PyTorch            | PyTorch                           | PyTorch     |
 
 ## Quick Start
 If you are familiar with LeRobot, getting started with OpenTau is very easy.
```
{opentau-0.1.2.dist-info → opentau-0.2.0.dist-info}/RECORD
CHANGED
```diff
@@ -2,26 +2,27 @@ opentau/__init__.py,sha256=fIqOYsZsF-bitUI-4taSNke_D1YJYCehGseZNe29GG0,6756
 opentau/__version__.py,sha256=junxoss59Jz_hmg3YnzhpVk_Q5Fo6uha23P1ET81N1c,889
 opentau/constants.py,sha256=-_CbJujCp6hbBjJHgYMguCTcSAkVkmdpM4wHqZp7vRQ,2020
 opentau/configs/__init__.py,sha256=hC-KkeCfq1mtMw9WjPCZfOTxrzQWW7hAa1w8BRC_Bqw,784
-opentau/configs/default.py,sha256=
+opentau/configs/default.py,sha256=T5QS84VNcoFaek8Ry7BsDAvsNsV5RLu2L2y2VvO3THo,15273
+opentau/configs/deployment.py,sha256=kem4WCkAKV6H_JUFEVyzw4h-EhQTPa6-9pY65Wr9wDQ,3106
 opentau/configs/libero.py,sha256=CrRfiCBYOw7hVqv6orH_ahNyQudj6iyqHtZM9YpdvzE,4688
 opentau/configs/parser.py,sha256=Pb7sw6yx38F31Aqw1J7wK6BRzfBA7DutywShyP_t9bY,14890
 opentau/configs/policies.py,sha256=06oUJx0B4V6krRwyjH1goTYM3RIpRozg4SSwcVJurG4,11667
 opentau/configs/reward.py,sha256=t7S8_RpEy31fAP4v_ygB-ETvaUR6OyrXmS1JNSz3cOk,1537
-opentau/configs/train.py,sha256=
+opentau/configs/train.py,sha256=nn2QX151wI-R-qbghyMkSv1miSPvNSUtsX2Swu-HVGU,18396
 opentau/configs/types.py,sha256=DvKasR2v0ecSmozL0YD4S-64OeuDYhVBhtspxUDV5u0,2453
 opentau/datasets/__init__.py,sha256=oLfV9vfFOg7o2XIRFiN5zOf529FafIkPwqFG7iUX4gc,4248
 opentau/datasets/backward_compatibility.py,sha256=ENVQk9QDPCip6NfAxNF6Vo4VvyCWb0acV-ZxcJBsB6o,3459
 opentau/datasets/compute_stats.py,sha256=N359TDuJicLKMtxxy0JVEcUtnTOB57gL5G8e9Dq0gMQ,13069
 opentau/datasets/dataset_mixture.py,sha256=8UWjY9oKn9jEMe-e9Dy6no1p_21H0kXKv8A10Ku_8_o,19850
-opentau/datasets/factory.py,sha256=
+opentau/datasets/factory.py,sha256=KVN8XEjeIdfTMohyftghG3dsM50UPjv05lL3eS5aTI4,12116
 opentau/datasets/image_writer.py,sha256=JYCkImHFYpLuE88t16cYqXqQS7EHS7g6kLWXPCJmWgw,11072
-opentau/datasets/lerobot_dataset.py,sha256=
+opentau/datasets/lerobot_dataset.py,sha256=f15Sy3jWOnuPiXiqB8pGdHqv3MOBZgToJyjcA7ry0JU,84778
 opentau/datasets/online_buffer.py,sha256=x14P8tBz25s-hRlE8loFJs5CAvh65RGWeogF271hiF0,19671
 opentau/datasets/sampler.py,sha256=5g-6prsWItVjqkt1J7mA9JPNQPhDSFx3r6rA4JphP9U,4012
 opentau/datasets/standard_data_format_mapping.py,sha256=wEKilksMUjJGeIhvyLuR9qhyhtiJMK1e1AzCkbyx-l4,9667
 opentau/datasets/transforms.py,sha256=pr_8vOEDUoWu7aOUdnI0_wgetsFuie3I2UYFrcStG1k,12976
 opentau/datasets/utils.py,sha256=bZ0Q8KPZMWe9fLdrqJqslgDxI9sa8uxqPQTxEyWwDKw,45062
-opentau/datasets/video_utils.py,sha256=
+opentau/datasets/video_utils.py,sha256=AUNUKr4IrDVetfYjzZj9Uq4GKrdHrzVYTcMoF1Jlggw,21968
 opentau/datasets/grounding/__init__.py,sha256=ojvdscCIjp5EQxptFAjPgvjKGZa_Xk9LLZ2wNUebWFw,3139
 opentau/datasets/grounding/base.py,sha256=FDAn2QPQHNBB7mzD7IQ2Bz882Dt1RPasBTgskXqKbP4,5773
 opentau/datasets/grounding/clevr.py,sha256=lNZ0hr5EhQKTh-eg35zujybcAo5p8JQEn9dW8DJhOjI,3983
```
```diff
@@ -59,9 +60,9 @@ opentau/policies/pi0/configuration_pi0.py,sha256=94EG2QlraDsPjD0zyuGwKPqToqV_ayP
 opentau/policies/pi0/modeling_pi0.py,sha256=rz1S7hDOVEv12sN0ECGupddKwQVXMqdvm7G_OooWZLA,37442
 opentau/policies/pi0/paligemma_with_expert.py,sha256=j9P6SL7MVP11MgyJqNnsrZAlOctWmqwDaJJAx4z9F84,20724
 opentau/policies/pi05/__init__.py,sha256=VcIjZwlRW1JChRHqAK9Vz4JAIEP40RrP-W-UdyR6xk4,821
-opentau/policies/pi05/configuration_pi05.py,sha256=
-opentau/policies/pi05/modeling_pi05.py,sha256=
-opentau/policies/pi05/paligemma_with_expert.py,sha256=
+opentau/policies/pi05/configuration_pi05.py,sha256=GmENtmgvI5q3gQQlZnH6RalV5msU5gwtTK_imbNH6a8,9367
+opentau/policies/pi05/modeling_pi05.py,sha256=sBmnP2cqaVHvCGGpg624LYbun52hZs63UCeadiE4dH4,63327
+opentau/policies/pi05/paligemma_with_expert.py,sha256=jyYkcOVMxMNYJtmbR3qZPmEoinOY1myqnq1JKmledkc,23407
 opentau/policies/value/__init__.py,sha256=wUP5vdpsvfnREqt4enfwaakzJ-ynX9sLYN14t2zEtpA,772
 opentau/policies/value/configuration_value.py,sha256=ApjrNKHxvjNlSZ71-BPvanNAJh9GzAK9W0JCiz3mMHs,5951
 opentau/policies/value/modeling_value.py,sha256=21x2EVGFlJbLasJDGTs3b5YMrrChrkuGxrYP7wLjkCY,18594
```
```diff
@@ -78,13 +79,16 @@ opentau/scripts/fake_tensor_training.py,sha256=y4F3CFs2jjpIJcT1wKvsrgFEebU9QFzba
 opentau/scripts/get_advantage_and_percentiles.py,sha256=JdjlADYzdS1Jc_19H6lLYMRnPlWxeckRSUQqwqb0rC4,8993
 opentau/scripts/high_level_planner_inference.py,sha256=nbXr8Hp64YGeprMTpT8kvT_NgpBlI02CUlO6Mm2Js_E,3846
 opentau/scripts/inference.py,sha256=_lp9YjPzarAnjiA8k2jBlIKZxza6PEHw--UyaqLPdNo,2110
-opentau/scripts/launch.py,sha256=
-opentau/scripts/libero_simulation_parallel.py,sha256=qMee6T0EwMoAT1J2u8X4w8rsbOJYwyqD3LRAPe2Ta1g,13105
-opentau/scripts/libero_simulation_sequential.py,sha256=xFSUQEuyai20QD-pYitp-UJPGE9zlaaIu4YSO0bhYKg,4775
+opentau/scripts/launch.py,sha256=L_KlkcJpcOsSMGlBKSmtTUyzb7q8tH4FkmuHx8lEdDI,2845
 opentau/scripts/nav_high_level_planner_inference.py,sha256=z2WHw68NWi-fJUd5TV4CrJHzxo-L7e2UliGjfOlqifM,1878
-opentau/scripts/train.py,sha256=
-opentau/scripts/visualize_dataset.py,sha256=
+opentau/scripts/train.py,sha256=dp_366gKpFeIcv2tfDkuuFGCdPyP74lkENETZpwR5m4,20547
+opentau/scripts/visualize_dataset.py,sha256=ZfB7Qbsl3RGqu8k7n6CK6iRbhuYhYSX167tta2b01NQ,13625
 opentau/scripts/zero_to_fp32.py,sha256=Rkl1ZczytKix9vGMg0EELzdJYFqUM1yB9p3xvSaK9k8,33272
+opentau/scripts/grpc/__init__.py,sha256=wBiZyRqF1PCpZHgqwHjSZaaeFRHLOX4ZggCXbAzngOs,799
+opentau/scripts/grpc/client.py,sha256=PbuAb14izNAItspdvppItpv6gaycBs2_mdcvrdtAXnQ,20181
+opentau/scripts/grpc/robot_inference_pb2.py,sha256=Se7elLBIHeoi3UIBXrn02w8rZVOg-gpBcuirt23d8Tg,4125
+opentau/scripts/grpc/robot_inference_pb2_grpc.py,sha256=LymJbsaei6C-pp9ffz2ETCVIhq4SRh--LdK8ne0-yug,7659
+opentau/scripts/grpc/server.py,sha256=x6BA0F0uYIXeBw4mnOXaYb2-y8zEbgMXoiRdtNvyq1g,11146
 opentau/utils/__init__.py,sha256=hIUeGPpZHf2AVf0-5C2p0BOcY0cFHCTT5yHn-SpEPwY,856
 opentau/utils/accelerate_utils.py,sha256=vXnSGo1hXCUNof-oNKLMJ_SOMjpKhpZ1gx21ObSsopI,2630
 opentau/utils/benchmark.py,sha256=jVli6gdBRMXAqNM3AIi43a0N_O1CLQMbKXsPK_e2y3s,3063
```
```diff
@@ -98,11 +102,11 @@ opentau/utils/logging_utils.py,sha256=zd7ypmk7aqVposPhA7Kg-PYrstapY4MsuTklsTD4r4
 opentau/utils/monkey_patch.py,sha256=cVgZ1N-NNVnlRKPA1dwO9FM4IbxR0V_Hbil6p-6knhA,9558
 opentau/utils/random_utils.py,sha256=k3Ab3Y98LozGdsBzKoP8xSsFTcnaRqUzY34BsETCrrA,9102
 opentau/utils/train_utils.py,sha256=0d7yvk8wlP-75pwB55gr095b_b1sWG5nlqdVxyH6_o0,6796
-opentau/utils/transformers_patch.py,sha256
+opentau/utils/transformers_patch.py,sha256=rPG2Yn7GQr2gCEykhW42uOoKP_jdAMx4p3q-IUcGYDI,10045
 opentau/utils/utils.py,sha256=DrMStfjBEkw_8WVhYMnCQJNBxMeozIJ8LBSpOtMQhFM,15760
-opentau-0.
-opentau-0.
-opentau-0.
-opentau-0.
-opentau-0.
-opentau-0.
+opentau-0.2.0.dist-info/licenses/LICENSE,sha256=tl3_NkxplsgU86xSvEWnDlE1UR_JsIvGo7t4hPtsIbE,27680
+opentau-0.2.0.dist-info/METADATA,sha256=OUN-KCCeTA1cFdlzmwRKecxepcmjpie6-AbbIkoELhI,12992
+opentau-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+opentau-0.2.0.dist-info/entry_points.txt,sha256=NGF_MWpSKri0lvjR9WGN4pBUap8B-z21f7XMluxc1M4,208
+opentau-0.2.0.dist-info/top_level.txt,sha256=7_yrS4x5KSeTRr2LICTCNOZmF-1_kSOFPKHvtJPL1Dw,8
+opentau-0.2.0.dist-info/RECORD,,
```