opentau 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/configs/default.py +16 -0
- opentau/configs/deployment.py +85 -0
- opentau/configs/train.py +5 -0
- opentau/datasets/factory.py +43 -10
- opentau/datasets/lerobot_dataset.py +19 -19
- opentau/datasets/video_utils.py +11 -6
- opentau/policies/pi05/configuration_pi05.py +9 -6
- opentau/policies/pi05/modeling_pi05.py +296 -30
- opentau/policies/pi05/paligemma_with_expert.py +20 -20
- opentau/scripts/grpc/__init__.py +19 -0
- opentau/scripts/grpc/client.py +601 -0
- opentau/scripts/grpc/robot_inference_pb2.py +61 -0
- opentau/scripts/grpc/robot_inference_pb2_grpc.py +210 -0
- opentau/scripts/grpc/server.py +313 -0
- opentau/scripts/launch.py +12 -4
- opentau/scripts/train.py +94 -17
- opentau/scripts/visualize_dataset.py +141 -38
- opentau/utils/transformers_patch.py +251 -20
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/METADATA +37 -17
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/RECORD +24 -21
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/WHEEL +1 -1
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/entry_points.txt +1 -0
- opentau/scripts/libero_simulation_parallel.py +0 -356
- opentau/scripts/libero_simulation_sequential.py +0 -122
- opentau/scripts/visualize_dataset_html.py +0 -507
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/top_level.txt +0 -0
opentau/scripts/train.py
CHANGED
|
@@ -17,6 +17,9 @@
|
|
|
17
17
|
import json
|
|
18
18
|
import logging
|
|
19
19
|
import os
|
|
20
|
+
|
|
21
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
22
|
+
|
|
20
23
|
from contextlib import nullcontext
|
|
21
24
|
from pprint import pformat
|
|
22
25
|
from typing import Any
|
|
@@ -176,7 +179,10 @@ def train(cfg: TrainPipelineConfig):
|
|
|
176
179
|
logging.info("Anomaly detection is disabled.")
|
|
177
180
|
|
|
178
181
|
logging.info("Creating dataset")
|
|
179
|
-
|
|
182
|
+
if cfg.val_freq > 0:
|
|
183
|
+
train_dataset, val_dataset = make_dataset_mixture(cfg)
|
|
184
|
+
else:
|
|
185
|
+
train_dataset = make_dataset_mixture(cfg)
|
|
180
186
|
|
|
181
187
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
|
182
188
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
|
@@ -188,7 +194,7 @@ def train(cfg: TrainPipelineConfig):
|
|
|
188
194
|
)
|
|
189
195
|
|
|
190
196
|
logging.info("Creating policy")
|
|
191
|
-
policy = make_policy(cfg=cfg.policy, ds_meta=
|
|
197
|
+
policy = make_policy(cfg=cfg.policy, ds_meta=train_dataset.meta)
|
|
192
198
|
policy.to(torch.bfloat16)
|
|
193
199
|
logging.info("Creating optimizer and scheduler")
|
|
194
200
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
|
@@ -203,11 +209,18 @@ def train(cfg: TrainPipelineConfig):
|
|
|
203
209
|
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
|
204
210
|
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
|
205
211
|
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
212
|
+
if cfg.val_freq > 0:
|
|
213
|
+
train_dataloader = train_dataset.get_dataloader()
|
|
214
|
+
val_dataloader = val_dataset.get_dataloader()
|
|
215
|
+
policy, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
|
|
216
|
+
policy, optimizer, train_dataloader, val_dataloader, lr_scheduler
|
|
217
|
+
)
|
|
218
|
+
else:
|
|
219
|
+
train_dataloader = train_dataset.get_dataloader()
|
|
220
|
+
policy, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
221
|
+
policy, optimizer, train_dataloader, lr_scheduler
|
|
222
|
+
)
|
|
223
|
+
train_dl_iter = cycle(train_dataloader)
|
|
211
224
|
|
|
212
225
|
# Register the LR scheduler for checkpointing
|
|
213
226
|
accelerator.register_for_checkpointing(lr_scheduler)
|
|
@@ -246,7 +259,7 @@ def train(cfg: TrainPipelineConfig):
|
|
|
246
259
|
for _ in range(cfg.gradient_accumulation_steps):
|
|
247
260
|
with accelerator.accumulate(policy) if cfg.gradient_accumulation_steps > 1 else nullcontext():
|
|
248
261
|
logging.debug(f"{step=}, {accelerator.sync_gradients=}")
|
|
249
|
-
batch = next(
|
|
262
|
+
batch = next(train_dl_iter)
|
|
250
263
|
|
|
251
264
|
train_tracker = update_policy(
|
|
252
265
|
cfg,
|
|
@@ -266,20 +279,21 @@ def train(cfg: TrainPipelineConfig):
|
|
|
266
279
|
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
|
267
280
|
is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and cfg.save_checkpoint
|
|
268
281
|
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
|
282
|
+
is_val_step = cfg.val_freq > 0 and step % cfg.val_freq == 0
|
|
269
283
|
|
|
270
284
|
# Only `train_tracker` on the main process keeps useful statistics,
|
|
271
285
|
# because we guarded it with if accelerator.is_main_process in the `update_policy` function.
|
|
272
286
|
if is_log_step and accelerator.is_main_process:
|
|
273
287
|
logging.info(train_tracker)
|
|
274
288
|
log_dict = train_tracker.to_dict(use_avg=True)
|
|
275
|
-
accelerator.log({"Training
|
|
276
|
-
accelerator.log({"MSE Loss": log_dict["mse_loss"]}, step=step)
|
|
277
|
-
accelerator.log({"CE Loss": log_dict["ce_loss"]}, step=step)
|
|
278
|
-
accelerator.log({"L1 Loss": log_dict["l1_loss"]}, step=step)
|
|
279
|
-
accelerator.log({"Accuracy": log_dict["accuracy"]}, step=step)
|
|
280
|
-
accelerator.log({"Learning Rate": log_dict["lr"]}, step=step)
|
|
281
|
-
accelerator.log({"Grad Norm": log_dict["grad_norm"]}, step=step)
|
|
282
|
-
accelerator.log({"Num Samples": log_dict["samples"]}, step=step)
|
|
289
|
+
accelerator.log({"Training/Loss": log_dict["loss"]}, step=step)
|
|
290
|
+
accelerator.log({"Training/MSE Loss": log_dict["mse_loss"]}, step=step)
|
|
291
|
+
accelerator.log({"Training/CE Loss": log_dict["ce_loss"]}, step=step)
|
|
292
|
+
accelerator.log({"Training/L1 Loss": log_dict["l1_loss"]}, step=step)
|
|
293
|
+
accelerator.log({"Training/Accuracy": log_dict["accuracy"]}, step=step)
|
|
294
|
+
accelerator.log({"Training/Learning Rate": log_dict["lr"]}, step=step)
|
|
295
|
+
accelerator.log({"Training/Grad Norm": log_dict["grad_norm"]}, step=step)
|
|
296
|
+
accelerator.log({"Training/Num Samples": log_dict["samples"]}, step=step)
|
|
283
297
|
train_tracker.reset_averages()
|
|
284
298
|
|
|
285
299
|
if is_saving_step:
|
|
@@ -299,6 +313,70 @@ def train(cfg: TrainPipelineConfig):
|
|
|
299
313
|
if cfg.last_checkpoint_only:
|
|
300
314
|
prune_old_checkpoints(checkpoint_dir)
|
|
301
315
|
|
|
316
|
+
accelerator.wait_for_everyone()
|
|
317
|
+
|
|
318
|
+
if is_val_step:
|
|
319
|
+
policy.eval()
|
|
320
|
+
val_metrics = {
|
|
321
|
+
"loss": AverageMeter("val_total_loss", ":.3f"),
|
|
322
|
+
"mse_loss": AverageMeter("val_mse_loss", ":.3f"),
|
|
323
|
+
"ce_loss": AverageMeter("val_ce_loss", ":.3f"),
|
|
324
|
+
"l1_loss": AverageMeter("val_l1_loss", ":.3f"),
|
|
325
|
+
"accuracy": AverageMeter("val_accuracy", ":.3f"),
|
|
326
|
+
}
|
|
327
|
+
val_tracker = MetricsTracker(
|
|
328
|
+
cfg.batch_size * accelerator.num_processes,
|
|
329
|
+
val_metrics,
|
|
330
|
+
initial_step=step,
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
logging.info(f"Validation at step {step}...")
|
|
334
|
+
|
|
335
|
+
with torch.no_grad():
|
|
336
|
+
for batch in val_dataloader:
|
|
337
|
+
losses = policy.forward(batch)
|
|
338
|
+
loss = cfg.loss_weighting["MSE"] * losses["MSE"] + cfg.loss_weighting["CE"] * losses["CE"]
|
|
339
|
+
|
|
340
|
+
# Gather and average metrics across processes
|
|
341
|
+
_first_loss_tensor = next(lt for lt in losses.values() if isinstance(lt, torch.Tensor))
|
|
342
|
+
zero = torch.tensor(0.0, device=_first_loss_tensor.device, dtype=_first_loss_tensor.dtype)
|
|
343
|
+
|
|
344
|
+
loss = accelerator.gather_for_metrics(loss).mean().item()
|
|
345
|
+
mse_loss = (
|
|
346
|
+
accelerator.gather_for_metrics(losses["MSE"]).to(dtype=torch.float32).mean().item()
|
|
347
|
+
)
|
|
348
|
+
ce_loss = (
|
|
349
|
+
accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item()
|
|
350
|
+
)
|
|
351
|
+
l1_loss = (
|
|
352
|
+
accelerator.gather_for_metrics(losses.get("L1", zero))
|
|
353
|
+
.to(dtype=torch.float32)
|
|
354
|
+
.mean()
|
|
355
|
+
.item()
|
|
356
|
+
)
|
|
357
|
+
accuracy = (
|
|
358
|
+
accelerator.gather_for_metrics(losses.get("Accuracy", zero))
|
|
359
|
+
.to(dtype=torch.float32)
|
|
360
|
+
.mean()
|
|
361
|
+
.item()
|
|
362
|
+
)
|
|
363
|
+
|
|
364
|
+
if accelerator.is_main_process:
|
|
365
|
+
val_tracker.loss = loss
|
|
366
|
+
val_tracker.mse_loss = mse_loss
|
|
367
|
+
val_tracker.ce_loss = ce_loss
|
|
368
|
+
val_tracker.l1_loss = l1_loss
|
|
369
|
+
val_tracker.accuracy = accuracy
|
|
370
|
+
|
|
371
|
+
if accelerator.is_main_process:
|
|
372
|
+
logging.info(val_tracker)
|
|
373
|
+
val_dict = val_tracker.to_dict(use_avg=True)
|
|
374
|
+
accelerator.log({"Validation/Loss": val_dict["loss"]}, step=step)
|
|
375
|
+
accelerator.log({"Validation/MSE Loss": val_dict["mse_loss"]}, step=step)
|
|
376
|
+
accelerator.log({"Validation/CE Loss": val_dict["ce_loss"]}, step=step)
|
|
377
|
+
accelerator.log({"Validation/L1 Loss": val_dict["l1_loss"]}, step=step)
|
|
378
|
+
accelerator.log({"Validation/Accuracy": val_dict["accuracy"]}, step=step)
|
|
379
|
+
|
|
302
380
|
# This barrier is probably necessary to ensure
|
|
303
381
|
# other processes wait for the main process to finish saving
|
|
304
382
|
accelerator.wait_for_everyone()
|
|
@@ -357,7 +435,6 @@ def train(cfg: TrainPipelineConfig):
|
|
|
357
435
|
with open(videos_dir / "eval_info.json", "w") as f:
|
|
358
436
|
json.dump(eval_info, f, indent=2)
|
|
359
437
|
|
|
360
|
-
if is_eval_step:
|
|
361
438
|
# This barrier is to ensure all processes finishes evaluation before the next training step
|
|
362
439
|
# Some processes might be slower than others
|
|
363
440
|
accelerator.wait_for_everyone()
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
15
|
# See the License for the specific language governing permissions and
|
|
16
16
|
# limitations under the License.
|
|
17
|
-
"""
|
|
17
|
+
"""Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
|
18
18
|
|
|
19
19
|
Note: The last frame of the episode doesn't always correspond to a final state.
|
|
20
20
|
That's because our datasets are composed of transition from state to state up to
|
|
@@ -30,34 +30,21 @@ Examples:
|
|
|
30
30
|
|
|
31
31
|
- Visualize data stored on a local machine:
|
|
32
32
|
```
|
|
33
|
-
local$
|
|
34
|
-
--repo-id lerobot/pusht \
|
|
35
|
-
--episode-index 0
|
|
33
|
+
local$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0
|
|
36
34
|
```
|
|
37
35
|
|
|
38
36
|
- Visualize data stored on a distant machine with a local viewer:
|
|
39
37
|
```
|
|
40
|
-
distant$
|
|
41
|
-
--repo-id lerobot/pusht \
|
|
42
|
-
--episode-index 0 \
|
|
43
|
-
--save 1 \
|
|
44
|
-
--output-dir path/to/directory
|
|
38
|
+
distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --save 1 --output-dir path/to/directory
|
|
45
39
|
|
|
46
40
|
local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
|
|
47
41
|
local$ rerun lerobot_pusht_episode_0.rrd
|
|
48
42
|
```
|
|
49
43
|
|
|
50
44
|
- Visualize data stored on a distant machine through streaming:
|
|
51
|
-
(You need to forward the websocket port to the distant machine, with
|
|
52
|
-
`ssh -L 9087:localhost:9087 username@remote-host`)
|
|
53
45
|
```
|
|
54
|
-
distant$ python src/opentau/scripts/visualize_dataset.py \
|
|
55
|
-
--repo-id lerobot/pusht \
|
|
56
|
-
--episode-index 0 \
|
|
57
|
-
--mode distant \
|
|
58
|
-
--ws-port 9087
|
|
59
46
|
|
|
60
|
-
|
|
47
|
+
distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode distant --web-port 9090
|
|
61
48
|
```
|
|
62
49
|
|
|
63
50
|
"""
|
|
@@ -65,7 +52,9 @@ local$ rerun ws://localhost:9087
|
|
|
65
52
|
import argparse
|
|
66
53
|
import gc
|
|
67
54
|
import logging
|
|
55
|
+
import os
|
|
68
56
|
import time
|
|
57
|
+
import warnings
|
|
69
58
|
from pathlib import Path
|
|
70
59
|
from typing import Iterator
|
|
71
60
|
|
|
@@ -75,8 +64,80 @@ import torch
|
|
|
75
64
|
import torch.utils.data
|
|
76
65
|
import tqdm
|
|
77
66
|
|
|
67
|
+
from opentau.configs.default import DatasetMixtureConfig, WandBConfig
|
|
68
|
+
from opentau.configs.train import TrainPipelineConfig
|
|
78
69
|
from opentau.datasets.lerobot_dataset import LeRobotDataset
|
|
79
|
-
|
|
70
|
+
|
|
71
|
+
PERMIT_URDF = hasattr(rr, "urdf")
|
|
72
|
+
if not PERMIT_URDF:
|
|
73
|
+
warnings.warn(
|
|
74
|
+
"`rerun.urdf` module not found. Make sure you have rerun >= 0.28.2 installed. "
|
|
75
|
+
" One way to ensure this is to install OpenTau with the '[urdf]' extra: `pip install opentau[urdf]`.",
|
|
76
|
+
stacklevel=2,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Older and newer versions of rerun have different APIs for setting time / sequence
|
|
81
|
+
def _rr_set_sequence(timeline: str, value: int):
|
|
82
|
+
if hasattr(rr, "set_time_sequence"):
|
|
83
|
+
rr.set_time_sequence(timeline, value)
|
|
84
|
+
else:
|
|
85
|
+
rr.set_time(timeline, sequence=value)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _rr_set_seconds(timeline: str, value: float):
|
|
89
|
+
if hasattr(rr, "set_time_seconds"):
|
|
90
|
+
rr.set_time_seconds(timeline, value)
|
|
91
|
+
else:
|
|
92
|
+
rr.set_time(timeline, timestamp=value)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _rr_scalar(value: float):
|
|
96
|
+
"""Return a rerun scalar archetype that works across rerun versions.
|
|
97
|
+
|
|
98
|
+
Older rerun versions expose `rr.Scalar`, while newer versions expose `rr.Scalars`.
|
|
99
|
+
This wrapper returns an object suitable for `rr.log(path, ...)` for a single value.
|
|
100
|
+
"""
|
|
101
|
+
v = float(value)
|
|
102
|
+
|
|
103
|
+
# New API (plural archetype)
|
|
104
|
+
if hasattr(rr, "Scalars"):
|
|
105
|
+
try:
|
|
106
|
+
return rr.Scalars(v)
|
|
107
|
+
except TypeError:
|
|
108
|
+
# Some versions expect a sequence/array for Scalars.
|
|
109
|
+
return rr.Scalars([v])
|
|
110
|
+
|
|
111
|
+
# Old API
|
|
112
|
+
if hasattr(rr, "Scalar"):
|
|
113
|
+
return rr.Scalar(v)
|
|
114
|
+
|
|
115
|
+
raise AttributeError("rerun has neither `Scalar` nor `Scalars` - please upgrade `rerun-sdk`.")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def create_mock_train_config() -> TrainPipelineConfig:
|
|
119
|
+
"""Create a mock TrainPipelineConfig for dataset visualization.
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
TrainPipelineConfig: A mock config with default values.
|
|
123
|
+
"""
|
|
124
|
+
return TrainPipelineConfig(
|
|
125
|
+
dataset_mixture=DatasetMixtureConfig(), # Will be set by the dataset
|
|
126
|
+
resolution=(224, 224),
|
|
127
|
+
num_cams=2,
|
|
128
|
+
max_state_dim=32,
|
|
129
|
+
max_action_dim=32,
|
|
130
|
+
action_chunk=50,
|
|
131
|
+
loss_weighting={"MSE": 1, "CE": 1},
|
|
132
|
+
num_workers=4,
|
|
133
|
+
batch_size=8,
|
|
134
|
+
steps=100_000,
|
|
135
|
+
log_freq=200,
|
|
136
|
+
save_checkpoint=True,
|
|
137
|
+
save_freq=20_000,
|
|
138
|
+
use_policy_training_preset=True,
|
|
139
|
+
wandb=WandBConfig(),
|
|
140
|
+
)
|
|
80
141
|
|
|
81
142
|
|
|
82
143
|
class EpisodeSampler(torch.utils.data.Sampler):
|
|
@@ -108,9 +169,9 @@ def visualize_dataset(
|
|
|
108
169
|
num_workers: int = 0,
|
|
109
170
|
mode: str = "local",
|
|
110
171
|
web_port: int = 9090,
|
|
111
|
-
ws_port: int = 9087,
|
|
112
172
|
save: bool = False,
|
|
113
173
|
output_dir: Path | None = None,
|
|
174
|
+
urdf: Path | None = None,
|
|
114
175
|
) -> Path | None:
|
|
115
176
|
if save:
|
|
116
177
|
assert output_dir is not None, (
|
|
@@ -141,16 +202,27 @@ def visualize_dataset(
|
|
|
141
202
|
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
|
|
142
203
|
gc.collect()
|
|
143
204
|
|
|
205
|
+
if urdf:
|
|
206
|
+
rr.log_file_from_path(urdf, static=True)
|
|
207
|
+
urdf_tree = rr.urdf.UrdfTree.from_file_path(urdf)
|
|
208
|
+
urdf_joints = [jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"]
|
|
209
|
+
print(
|
|
210
|
+
"Assuming the dataset state dimensions correspond to URDF joints in order:\n",
|
|
211
|
+
"\n".join(f"{i:3d}: {jnt.name}" for i, jnt in enumerate(urdf_joints)),
|
|
212
|
+
)
|
|
213
|
+
else:
|
|
214
|
+
urdf_joints = []
|
|
215
|
+
|
|
144
216
|
if mode == "distant":
|
|
145
|
-
rr.
|
|
217
|
+
rr.serve_web_viewer(open_browser=False, web_port=web_port)
|
|
146
218
|
|
|
147
219
|
logging.info("Logging to Rerun")
|
|
148
220
|
|
|
149
221
|
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
|
150
222
|
# iterate over the batch
|
|
151
223
|
for i in range(len(batch["index"])):
|
|
152
|
-
|
|
153
|
-
|
|
224
|
+
_rr_set_sequence("frame_index", batch["frame_index"][i].item())
|
|
225
|
+
_rr_set_seconds("timestamp", batch["timestamp"][i].item())
|
|
154
226
|
|
|
155
227
|
# display each camera image
|
|
156
228
|
for key in dataset.meta.camera_keys:
|
|
@@ -160,21 +232,27 @@ def visualize_dataset(
|
|
|
160
232
|
# display each dimension of action space (e.g. actuators command)
|
|
161
233
|
if "action" in batch:
|
|
162
234
|
for dim_idx, val in enumerate(batch["action"][i]):
|
|
163
|
-
rr.log(f"action/{dim_idx}",
|
|
235
|
+
rr.log(f"action/{dim_idx}", _rr_scalar(val.item()))
|
|
164
236
|
|
|
165
237
|
# display each dimension of observed state space (e.g. agent position in joint space)
|
|
166
238
|
if "observation.state" in batch:
|
|
167
239
|
for dim_idx, val in enumerate(batch["observation.state"][i]):
|
|
168
|
-
rr.log(f"state/{dim_idx}",
|
|
240
|
+
rr.log(f"state/{dim_idx}", _rr_scalar(val.item()))
|
|
241
|
+
# Assuming the state dimensions correspond to URDF joints in order.
|
|
242
|
+
# TODO(shuheng): allow overriding with a mapping from state dim to joint name.
|
|
243
|
+
if dim_idx < len(urdf_joints):
|
|
244
|
+
joint = urdf_joints[dim_idx]
|
|
245
|
+
transform = joint.compute_transform(float(val))
|
|
246
|
+
rr.log("URDF", transform)
|
|
169
247
|
|
|
170
248
|
if "next.done" in batch:
|
|
171
|
-
rr.log("next.done",
|
|
249
|
+
rr.log("next.done", _rr_scalar(batch["next.done"][i].item()))
|
|
172
250
|
|
|
173
251
|
if "next.reward" in batch:
|
|
174
|
-
rr.log("next.reward",
|
|
252
|
+
rr.log("next.reward", _rr_scalar(batch["next.reward"][i].item()))
|
|
175
253
|
|
|
176
254
|
if "next.success" in batch:
|
|
177
|
-
rr.log("next.success",
|
|
255
|
+
rr.log("next.success", _rr_scalar(batch["next.success"][i].item()))
|
|
178
256
|
|
|
179
257
|
if mode == "local" and save:
|
|
180
258
|
# save .rrd locally
|
|
@@ -194,7 +272,7 @@ def visualize_dataset(
|
|
|
194
272
|
print("Ctrl-C received. Exiting.")
|
|
195
273
|
|
|
196
274
|
|
|
197
|
-
def
|
|
275
|
+
def parse_args() -> dict:
|
|
198
276
|
parser = argparse.ArgumentParser()
|
|
199
277
|
|
|
200
278
|
parser.add_argument(
|
|
@@ -250,12 +328,6 @@ def main():
|
|
|
250
328
|
default=9090,
|
|
251
329
|
help="Web port for rerun.io when `--mode distant` is set.",
|
|
252
330
|
)
|
|
253
|
-
parser.add_argument(
|
|
254
|
-
"--ws-port",
|
|
255
|
-
type=int,
|
|
256
|
-
default=9087,
|
|
257
|
-
help="Web socket port for rerun.io when `--mode distant` is set.",
|
|
258
|
-
)
|
|
259
331
|
parser.add_argument(
|
|
260
332
|
"--save",
|
|
261
333
|
type=int,
|
|
@@ -266,7 +338,6 @@ def main():
|
|
|
266
338
|
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
|
|
267
339
|
),
|
|
268
340
|
)
|
|
269
|
-
|
|
270
341
|
parser.add_argument(
|
|
271
342
|
"--tolerance-s",
|
|
272
343
|
type=float,
|
|
@@ -277,17 +348,49 @@ def main():
|
|
|
277
348
|
"If not given, defaults to 1e-4."
|
|
278
349
|
),
|
|
279
350
|
)
|
|
351
|
+
parser.add_argument(
|
|
352
|
+
"--urdf",
|
|
353
|
+
type=Path,
|
|
354
|
+
default=None,
|
|
355
|
+
help="Path to a URDF file to load and visualize alongside the dataset.",
|
|
356
|
+
)
|
|
357
|
+
parser.add_argument(
|
|
358
|
+
"--urdf-package-dir",
|
|
359
|
+
type=Path,
|
|
360
|
+
default=None,
|
|
361
|
+
help=(
|
|
362
|
+
"Root directory of the URDF package to resolve package:// paths. "
|
|
363
|
+
"You can also set the ROS_PACKAGE_PATH environment variable, "
|
|
364
|
+
"which will be used if this argument is not provided."
|
|
365
|
+
),
|
|
366
|
+
)
|
|
280
367
|
|
|
281
368
|
args = parser.parse_args()
|
|
282
|
-
|
|
369
|
+
return vars(args)
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def main():
|
|
373
|
+
kwargs = parse_args()
|
|
283
374
|
repo_id = kwargs.pop("repo_id")
|
|
284
375
|
root = kwargs.pop("root")
|
|
285
376
|
tolerance_s = kwargs.pop("tolerance_s")
|
|
377
|
+
urdf_package_dir = kwargs.pop("urdf_package_dir")
|
|
378
|
+
if urdf_package_dir:
|
|
379
|
+
os.environ["ROS_PACKAGE_PATH"] = urdf_package_dir.resolve().as_posix()
|
|
380
|
+
|
|
381
|
+
if not PERMIT_URDF:
|
|
382
|
+
kwargs["urdf"] = None
|
|
286
383
|
|
|
287
384
|
logging.info("Loading dataset")
|
|
288
|
-
dataset = LeRobotDataset(
|
|
385
|
+
dataset = LeRobotDataset(
|
|
386
|
+
create_mock_train_config(),
|
|
387
|
+
repo_id,
|
|
388
|
+
root=root,
|
|
389
|
+
tolerance_s=tolerance_s,
|
|
390
|
+
standardize=False,
|
|
391
|
+
)
|
|
289
392
|
|
|
290
|
-
visualize_dataset(dataset, **
|
|
393
|
+
visualize_dataset(dataset, **kwargs)
|
|
291
394
|
|
|
292
395
|
|
|
293
396
|
if __name__ == "__main__":
|