opentau 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two published versions.
opentau/scripts/train.py CHANGED
@@ -17,6 +17,9 @@
 import json
 import logging
 import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 from contextlib import nullcontext
 from pprint import pformat
 from typing import Any
@@ -176,7 +179,10 @@ def train(cfg: TrainPipelineConfig):
     logging.info("Anomaly detection is disabled.")

     logging.info("Creating dataset")
-    dataset = make_dataset_mixture(cfg)
+    if cfg.val_freq > 0:
+        train_dataset, val_dataset = make_dataset_mixture(cfg)
+    else:
+        train_dataset = make_dataset_mixture(cfg)

     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -188,7 +194,7 @@ def train(cfg: TrainPipelineConfig):
     )

     logging.info("Creating policy")
-    policy = make_policy(cfg=cfg.policy, ds_meta=dataset.meta)
+    policy = make_policy(cfg=cfg.policy, ds_meta=train_dataset.meta)
     policy.to(torch.bfloat16)
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -203,11 +209,18 @@ def train(cfg: TrainPipelineConfig):
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

-    dataloader = dataset.get_dataloader()
-    policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
-        policy, optimizer, dataloader, lr_scheduler
-    )
-    dl_iter = cycle(dataloader)
+    if cfg.val_freq > 0:
+        train_dataloader = train_dataset.get_dataloader()
+        val_dataloader = val_dataset.get_dataloader()
+        policy, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, train_dataloader, val_dataloader, lr_scheduler
+        )
+    else:
+        train_dataloader = train_dataset.get_dataloader()
+        policy, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, train_dataloader, lr_scheduler
+        )
+    train_dl_iter = cycle(train_dataloader)

     # Register the LR scheduler for checkpointing
     accelerator.register_for_checkpointing(lr_scheduler)
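
The branch above mirrors a standard Hugging Face Accelerate pattern: prepare every dataloader you will iterate, and wrap only the training loader in an infinite iterator. A minimal, self-contained sketch of that pattern (not OpenTau code; `val_freq` stands in for `cfg.val_freq`, and the model and datasets are toy stand-ins):

```python
from itertools import cycle

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()
val_freq = 500  # stands in for cfg.val_freq; 0 disables validation

train_ds = TensorDataset(torch.randn(256, 4), torch.randn(256, 1))
val_ds = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
if val_freq > 0:
    val_loader = DataLoader(val_ds, batch_size=32)
    # prepare() wraps every object it is given and returns them in the same order
    model, optimizer, train_loader, val_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader
    )
else:
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

# training is step-based rather than epoch-based, so cycle() makes the loader infinite
train_iter = cycle(train_loader)
batch = next(train_iter)
```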
@@ -246,7 +259,7 @@ def train(cfg: TrainPipelineConfig):
         for _ in range(cfg.gradient_accumulation_steps):
             with accelerator.accumulate(policy) if cfg.gradient_accumulation_steps > 1 else nullcontext():
                 logging.debug(f"{step=}, {accelerator.sync_gradients=}")
-                batch = next(dl_iter)
+                batch = next(train_dl_iter)

                 train_tracker = update_policy(
                     cfg,
@@ -266,20 +279,21 @@ def train(cfg: TrainPipelineConfig):
         is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
         is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and cfg.save_checkpoint
         is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
+        is_val_step = cfg.val_freq > 0 and step % cfg.val_freq == 0

         # Only `train_tracker` on the main process keeps useful statistics,
         # because we guarded it with if accelerator.is_main_process in the `update_policy` function.
         if is_log_step and accelerator.is_main_process:
             logging.info(train_tracker)
             log_dict = train_tracker.to_dict(use_avg=True)
-            accelerator.log({"Training Loss": log_dict["loss"]}, step=step)
-            accelerator.log({"MSE Loss": log_dict["mse_loss"]}, step=step)
-            accelerator.log({"CE Loss": log_dict["ce_loss"]}, step=step)
-            accelerator.log({"L1 Loss": log_dict["l1_loss"]}, step=step)
-            accelerator.log({"Accuracy": log_dict["accuracy"]}, step=step)
-            accelerator.log({"Learning Rate": log_dict["lr"]}, step=step)
-            accelerator.log({"Grad Norm": log_dict["grad_norm"]}, step=step)
-            accelerator.log({"Num Samples": log_dict["samples"]}, step=step)
+            accelerator.log({"Training/Loss": log_dict["loss"]}, step=step)
+            accelerator.log({"Training/MSE Loss": log_dict["mse_loss"]}, step=step)
+            accelerator.log({"Training/CE Loss": log_dict["ce_loss"]}, step=step)
+            accelerator.log({"Training/L1 Loss": log_dict["l1_loss"]}, step=step)
+            accelerator.log({"Training/Accuracy": log_dict["accuracy"]}, step=step)
+            accelerator.log({"Training/Learning Rate": log_dict["lr"]}, step=step)
+            accelerator.log({"Training/Grad Norm": log_dict["grad_norm"]}, step=step)
+            accelerator.log({"Training/Num Samples": log_dict["samples"]}, step=step)
             train_tracker.reset_averages()

         if is_saving_step:
@@ -299,6 +313,70 @@ def train(cfg: TrainPipelineConfig):
             if cfg.last_checkpoint_only:
                 prune_old_checkpoints(checkpoint_dir)

+        accelerator.wait_for_everyone()
+
+        if is_val_step:
+            policy.eval()
+            val_metrics = {
+                "loss": AverageMeter("val_total_loss", ":.3f"),
+                "mse_loss": AverageMeter("val_mse_loss", ":.3f"),
+                "ce_loss": AverageMeter("val_ce_loss", ":.3f"),
+                "l1_loss": AverageMeter("val_l1_loss", ":.3f"),
+                "accuracy": AverageMeter("val_accuracy", ":.3f"),
+            }
+            val_tracker = MetricsTracker(
+                cfg.batch_size * accelerator.num_processes,
+                val_metrics,
+                initial_step=step,
+            )
+
+            logging.info(f"Validation at step {step}...")
+
+            with torch.no_grad():
+                for batch in val_dataloader:
+                    losses = policy.forward(batch)
+                    loss = cfg.loss_weighting["MSE"] * losses["MSE"] + cfg.loss_weighting["CE"] * losses["CE"]
+
+                    # Gather and average metrics across processes
+                    _first_loss_tensor = next(lt for lt in losses.values() if isinstance(lt, torch.Tensor))
+                    zero = torch.tensor(0.0, device=_first_loss_tensor.device, dtype=_first_loss_tensor.dtype)
+
+                    loss = accelerator.gather_for_metrics(loss).mean().item()
+                    mse_loss = (
+                        accelerator.gather_for_metrics(losses["MSE"]).to(dtype=torch.float32).mean().item()
+                    )
+                    ce_loss = (
+                        accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item()
+                    )
+                    l1_loss = (
+                        accelerator.gather_for_metrics(losses.get("L1", zero))
+                        .to(dtype=torch.float32)
+                        .mean()
+                        .item()
+                    )
+                    accuracy = (
+                        accelerator.gather_for_metrics(losses.get("Accuracy", zero))
+                        .to(dtype=torch.float32)
+                        .mean()
+                        .item()
+                    )
+
+                    if accelerator.is_main_process:
+                        val_tracker.loss = loss
+                        val_tracker.mse_loss = mse_loss
+                        val_tracker.ce_loss = ce_loss
+                        val_tracker.l1_loss = l1_loss
+                        val_tracker.accuracy = accuracy
+
+            if accelerator.is_main_process:
+                logging.info(val_tracker)
+                val_dict = val_tracker.to_dict(use_avg=True)
+                accelerator.log({"Validation/Loss": val_dict["loss"]}, step=step)
+                accelerator.log({"Validation/MSE Loss": val_dict["mse_loss"]}, step=step)
+                accelerator.log({"Validation/CE Loss": val_dict["ce_loss"]}, step=step)
+                accelerator.log({"Validation/L1 Loss": val_dict["l1_loss"]}, step=step)
+                accelerator.log({"Validation/Accuracy": val_dict["accuracy"]}, step=step)
+
         # This barrier is probably necessary to ensure
         # other processes wait for the main process to finish saving
         accelerator.wait_for_everyone()
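
The validation block added here follows the usual Accelerate recipe: switch to eval mode, iterate the validation loader under `torch.no_grad()`, and reduce each batch loss across processes with `gather_for_metrics`. A minimal, self-contained sketch of that recipe (toy model and data; it does not use OpenTau's `MetricsTracker`):

```python
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 1))
val_loader = accelerator.prepare(
    DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=16)
)

model.eval()
total, n_batches = 0.0, 0
with torch.no_grad():
    for x, y in val_loader:
        loss = torch.nn.functional.mse_loss(model(x), y)
        # gather_for_metrics collects the per-process loss so the average reflects all ranks
        total += accelerator.gather_for_metrics(loss).mean().item()
        n_batches += 1
if accelerator.is_main_process:
    print(f"Validation/Loss: {total / max(n_batches, 1):.3f}")
model.train()
```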
@@ -357,7 +435,6 @@ def train(cfg: TrainPipelineConfig):
             with open(videos_dir / "eval_info.json", "w") as f:
                 json.dump(eval_info, f, indent=2)

-        if is_eval_step:
         # This barrier is to ensure all processes finishes evaluation before the next training step
         # Some processes might be slower than others
         accelerator.wait_for_everyone()
opentau/scripts/visualize_dataset.py CHANGED
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+"""Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.

 Note: The last frame of the episode doesn't always correspond to a final state.
 That's because our datasets are composed of transition from state to state up to
@@ -30,34 +30,21 @@ Examples:

 - Visualize data stored on a local machine:
 ```
-local$ python src/opentau/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0
+local$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0
 ```

 - Visualize data stored on a distant machine with a local viewer:
 ```
-distant$ python src/opentau/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --save 1 \
-    --output-dir path/to/directory
+distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --save 1 --output-dir path/to/directory

 local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
 local$ rerun lerobot_pusht_episode_0.rrd
 ```

 - Visualize data stored on a distant machine through streaming:
-  (You need to forward the websocket port to the distant machine, with
-  `ssh -L 9087:localhost:9087 username@remote-host`)
 ```
-distant$ python src/opentau/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --mode distant \
-    --ws-port 9087

-local$ rerun ws://localhost:9087
+distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode distant --web-port 9090
 ```

 """
@@ -65,7 +52,9 @@ local$ rerun ws://localhost:9087
 import argparse
 import gc
 import logging
+import os
 import time
+import warnings
 from pathlib import Path
 from typing import Iterator

@@ -75,8 +64,80 @@ import torch
 import torch.utils.data
 import tqdm

+from opentau.configs.default import DatasetMixtureConfig, WandBConfig
+from opentau.configs.train import TrainPipelineConfig
 from opentau.datasets.lerobot_dataset import LeRobotDataset
-from opentau.scripts.visualize_dataset_html import create_mock_train_config
+
+PERMIT_URDF = hasattr(rr, "urdf")
+if not PERMIT_URDF:
+    warnings.warn(
+        "`rerun.urdf` module not found. Make sure you have rerun >= 0.28.2 installed. "
+        " One way to ensure this is to install OpenTau with the '[urdf]' extra: `pip install opentau[urdf]`.",
+        stacklevel=2,
+    )
+
+
+# Older and newer versions of rerun have different APIs for setting time / sequence
+def _rr_set_sequence(timeline: str, value: int):
+    if hasattr(rr, "set_time_sequence"):
+        rr.set_time_sequence(timeline, value)
+    else:
+        rr.set_time(timeline, sequence=value)
+
+
+def _rr_set_seconds(timeline: str, value: float):
+    if hasattr(rr, "set_time_seconds"):
+        rr.set_time_seconds(timeline, value)
+    else:
+        rr.set_time(timeline, timestamp=value)
+
+
+def _rr_scalar(value: float):
+    """Return a rerun scalar archetype that works across rerun versions.
+
+    Older rerun versions expose `rr.Scalar`, while newer versions expose `rr.Scalars`.
+    This wrapper returns an object suitable for `rr.log(path, ...)` for a single value.
+    """
+    v = float(value)
+
+    # New API (plural archetype)
+    if hasattr(rr, "Scalars"):
+        try:
+            return rr.Scalars(v)
+        except TypeError:
+            # Some versions expect a sequence/array for Scalars.
+            return rr.Scalars([v])
+
+    # Old API
+    if hasattr(rr, "Scalar"):
+        return rr.Scalar(v)
+
+    raise AttributeError("rerun has neither `Scalar` nor `Scalars` - please upgrade `rerun-sdk`.")
+
+
+def create_mock_train_config() -> TrainPipelineConfig:
+    """Create a mock TrainPipelineConfig for dataset visualization.
+
+    Returns:
+        TrainPipelineConfig: A mock config with default values.
+    """
+    return TrainPipelineConfig(
+        dataset_mixture=DatasetMixtureConfig(),  # Will be set by the dataset
+        resolution=(224, 224),
+        num_cams=2,
+        max_state_dim=32,
+        max_action_dim=32,
+        action_chunk=50,
+        loss_weighting={"MSE": 1, "CE": 1},
+        num_workers=4,
+        batch_size=8,
+        steps=100_000,
+        log_freq=200,
+        save_checkpoint=True,
+        save_freq=20_000,
+        use_policy_training_preset=True,
+        wandb=WandBConfig(),
+    )


 class EpisodeSampler(torch.utils.data.Sampler):
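
The `_rr_set_sequence` / `_rr_set_seconds` / `_rr_scalar` helpers introduced above pick the right rerun API at runtime. A short usage sketch, assuming rerun-sdk is installed and the helpers are in scope (for example, run inside this module); the application id and file name are placeholders:

```python
import rerun as rr

rr.init("opentau_compat_demo", spawn=False)   # placeholder application id
rr.save("compat_demo.rrd")                    # write to a file instead of opening a viewer

_rr_set_sequence("frame_index", 0)    # set_time_sequence on old SDKs, set_time(..., sequence=...) on new ones
_rr_set_seconds("timestamp", 0.033)   # set_time_seconds vs set_time(..., timestamp=...)
rr.log("action/0", _rr_scalar(0.12))  # rr.Scalar vs rr.Scalars, resolved at runtime
```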
@@ -108,9 +169,9 @@ def visualize_dataset(
     num_workers: int = 0,
     mode: str = "local",
     web_port: int = 9090,
-    ws_port: int = 9087,
     save: bool = False,
     output_dir: Path | None = None,
+    urdf: Path | None = None,
 ) -> Path | None:
     if save:
         assert output_dir is not None, (
@@ -141,16 +202,27 @@ def visualize_dataset(
     # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
     gc.collect()

+    if urdf:
+        rr.log_file_from_path(urdf, static=True)
+        urdf_tree = rr.urdf.UrdfTree.from_file_path(urdf)
+        urdf_joints = [jnt for jnt in urdf_tree.joints() if jnt.joint_type != "fixed"]
+        print(
+            "Assuming the dataset state dimensions correspond to URDF joints in order:\n",
+            "\n".join(f"{i:3d}: {jnt.name}" for i, jnt in enumerate(urdf_joints)),
+        )
+    else:
+        urdf_joints = []
+
     if mode == "distant":
-        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+        rr.serve_web_viewer(open_browser=False, web_port=web_port)

     logging.info("Logging to Rerun")

     for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
         # iterate over the batch
         for i in range(len(batch["index"])):
-            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
-            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+            _rr_set_sequence("frame_index", batch["frame_index"][i].item())
+            _rr_set_seconds("timestamp", batch["timestamp"][i].item())

             # display each camera image
             for key in dataset.meta.camera_keys:
@@ -160,21 +232,27 @@ def visualize_dataset(
             # display each dimension of action space (e.g. actuators command)
             if "action" in batch:
                 for dim_idx, val in enumerate(batch["action"][i]):
-                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
+                    rr.log(f"action/{dim_idx}", _rr_scalar(val.item()))

             # display each dimension of observed state space (e.g. agent position in joint space)
             if "observation.state" in batch:
                 for dim_idx, val in enumerate(batch["observation.state"][i]):
-                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
+                    rr.log(f"state/{dim_idx}", _rr_scalar(val.item()))
+                    # Assuming the state dimensions correspond to URDF joints in order.
+                    # TODO(shuheng): allow overriding with a mapping from state dim to joint name.
+                    if dim_idx < len(urdf_joints):
+                        joint = urdf_joints[dim_idx]
+                        transform = joint.compute_transform(float(val))
+                        rr.log("URDF", transform)

             if "next.done" in batch:
-                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
+                rr.log("next.done", _rr_scalar(batch["next.done"][i].item()))

             if "next.reward" in batch:
-                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+                rr.log("next.reward", _rr_scalar(batch["next.reward"][i].item()))

             if "next.success" in batch:
-                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
+                rr.log("next.success", _rr_scalar(batch["next.success"][i].item()))

     if mode == "local" and save:
         # save .rrd locally
@@ -194,7 +272,7 @@ def visualize_dataset(
         print("Ctrl-C received. Exiting.")


-def main():
+def parse_args() -> dict:
     parser = argparse.ArgumentParser()

     parser.add_argument(
@@ -250,12 +328,6 @@ def main():
         default=9090,
         help="Web port for rerun.io when `--mode distant` is set.",
     )
-    parser.add_argument(
-        "--ws-port",
-        type=int,
-        default=9087,
-        help="Web socket port for rerun.io when `--mode distant` is set.",
-    )
     parser.add_argument(
         "--save",
         type=int,
@@ -266,7 +338,6 @@ def main():
             "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
         ),
     )
-
     parser.add_argument(
         "--tolerance-s",
         type=float,
@@ -277,17 +348,49 @@ def main():
             "If not given, defaults to 1e-4."
         ),
     )
+    parser.add_argument(
+        "--urdf",
+        type=Path,
+        default=None,
+        help="Path to a URDF file to load and visualize alongside the dataset.",
+    )
+    parser.add_argument(
+        "--urdf-package-dir",
+        type=Path,
+        default=None,
+        help=(
+            "Root directory of the URDF package to resolve package:// paths. "
+            "You can also set the ROS_PACKAGE_PATH environment variable, "
+            "which will be used if this argument is not provided."
+        ),
+    )

     args = parser.parse_args()
-    kwargs = vars(args)
+    return vars(args)
+
+
+def main():
+    kwargs = parse_args()
     repo_id = kwargs.pop("repo_id")
     root = kwargs.pop("root")
     tolerance_s = kwargs.pop("tolerance_s")
+    urdf_package_dir = kwargs.pop("urdf_package_dir")
+    if urdf_package_dir:
+        os.environ["ROS_PACKAGE_PATH"] = urdf_package_dir.resolve().as_posix()
+
+    if not PERMIT_URDF:
+        kwargs["urdf"] = None

     logging.info("Loading dataset")
-    dataset = LeRobotDataset(create_mock_train_config(), repo_id, root=root, tolerance_s=tolerance_s)
+    dataset = LeRobotDataset(
+        create_mock_train_config(),
+        repo_id,
+        root=root,
+        tolerance_s=tolerance_s,
+        standardize=False,
+    )

-    visualize_dataset(dataset, **vars(args))
+    visualize_dataset(dataset, **kwargs)


 if __name__ == "__main__":
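
Beyond the `opentau-dataset-viz` CLI, the updated `main()` also suggests a programmatic entry point. A hedged sketch of driving it from Python, mirroring the calls visible in this diff; `episode_index` and the repo id are assumptions taken from the `--episode-index` flag and the docstring examples, not from the visible signature:

```python
from pathlib import Path

from opentau.datasets.lerobot_dataset import LeRobotDataset
from opentau.scripts.visualize_dataset import create_mock_train_config, visualize_dataset

# Build the dataset the same way the new main() does (raw values, no standardization).
dataset = LeRobotDataset(
    create_mock_train_config(),
    "lerobot/pusht",      # placeholder repo id from the docstring examples
    tolerance_s=1e-4,
    standardize=False,
)

# Parameters outside the visible hunk (episode_index in particular) are assumed.
visualize_dataset(
    dataset,
    episode_index=0,
    mode="local",
    save=True,
    output_dir=Path("outputs/viz"),
    urdf=None,            # or Path("robot.urdf") with rerun >= 0.28.2 installed
)
```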