opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/scripts/train.py
@@ -0,0 +1,379 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import json
18
+ import logging
19
+ import os
20
+ from contextlib import nullcontext
21
+ from pprint import pformat
22
+ from typing import Any
23
+
24
+ import accelerate
25
+ import torch
26
+ from accelerate.optimizer import AcceleratedOptimizer
27
+ from accelerate.scheduler import AcceleratedScheduler
28
+ from accelerate.utils import DistributedDataParallelKwargs, gather_object
29
+ from termcolor import colored
30
+
31
+ from opentau.configs import parser
32
+ from opentau.configs.train import TrainPipelineConfig
33
+ from opentau.datasets.factory import make_dataset_mixture
34
+ from opentau.datasets.utils import cycle
35
+ from opentau.envs.factory import make_envs
36
+ from opentau.envs.utils import close_envs
37
+ from opentau.optim.factory import make_optimizer_and_scheduler
38
+ from opentau.policies.factory import make_policy
39
+ from opentau.policies.pretrained import PreTrainedPolicy
40
+ from opentau.scripts.eval import consolidate_eval_info, eval_policy_all
41
+ from opentau.utils.accelerate_utils import set_proc_accelerator
42
+ from opentau.utils.logging_utils import AverageMeter, MetricsTracker
43
+ from opentau.utils.random_utils import set_seed
44
+ from opentau.utils.train_utils import (
45
+ get_step_checkpoint_dir,
46
+ get_step_identifier,
47
+ load_training_state,
48
+ load_training_step,
49
+ prune_old_checkpoints,
50
+ save_checkpoint,
51
+ )
52
+ from opentau.utils.utils import (
53
+ encode_accelerator_state_dict,
54
+ format_big_number,
55
+ init_logging,
56
+ is_launched_with_accelerate,
57
+ )
58
+
59
+
60
+ def update_policy(
61
+ train_config: TrainPipelineConfig,
62
+ train_metrics: MetricsTracker,
63
+ policy: PreTrainedPolicy,
64
+ batch: Any,
65
+ optimizer: AcceleratedOptimizer,
66
+ grad_clip_norm: float,
67
+ accelerator: accelerate.Accelerator,
68
+ lr_scheduler: AcceleratedScheduler | None = None,
69
+ ) -> MetricsTracker:
70
+ policy.train()
71
+ losses = policy.forward(batch)
72
+ loss = (
73
+ train_config.loss_weighting["MSE"] * losses["MSE"] + train_config.loss_weighting["CE"] * losses["CE"]
74
+ )
75
+
76
+ accelerator.backward(loss)
77
+ accelerator.unscale_gradients(optimizer=optimizer)
78
+
79
+ if accelerator.sync_gradients:
80
+ grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
81
+ if accelerator.is_main_process:
82
+ train_metrics.grad_norm = grad_norm
83
+
84
+ optimizer.step()
85
+ optimizer.zero_grad()
86
+
87
+ # Step through pytorch scheduler at every batch instead of epoch
88
+ if lr_scheduler is not None:
89
+ lr_scheduler.step()
90
+
91
+ # This calls `torch.distributed.all_gather_into_tensor` under the hood, which is not especially efficient:
92
+ # only the main process needs the gathered values, yet they end up on every process.
93
+ # Nonetheless, we still do this for correctness, safety, and simplicity.
94
+ _first_loss_tensor = next(lt for lt in losses.values() if isinstance(lt, torch.Tensor))
95
+ zero = torch.tensor(0.0, device=_first_loss_tensor.device, dtype=_first_loss_tensor.dtype)
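+ # `zero` is a device- and dtype-matched fallback so `gather_for_metrics` always receives a tensor,
+ # even when an optional loss term (e.g. "L1" or "Accuracy") is absent from the policy's output.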
96
+ loss = accelerator.gather_for_metrics(loss).mean().item()
97
+ mse_loss = accelerator.gather_for_metrics(losses["MSE"]).to(dtype=torch.float32).mean().item()
98
+ ce_loss = accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item()
99
+ l1_loss = accelerator.gather_for_metrics(losses.get("L1", zero)).to(dtype=torch.float32).mean().item()
100
+ accuracy = (
101
+ accelerator.gather_for_metrics(losses.get("Accuracy", zero)).to(dtype=torch.float32).mean().item()
102
+ )
103
+ # This actually calls the `.update` method of the `AverageMeter` class. This operation is not idempotent.
104
+ # See MetricsTracker.__setattr__ for more details.
105
+ # In other words, setting `train_metrics.loss = 1` and `train_metrics.loss = 2` consecutively results in
106
+ # an average of 1.5 when formatted as a string, not just 2.
107
+ if accelerator.is_main_process:
108
+ train_metrics.loss = loss
109
+ train_metrics.mse_loss = mse_loss
110
+ train_metrics.ce_loss = ce_loss
111
+ train_metrics.l1_loss = l1_loss
112
+ train_metrics.accuracy = accuracy
113
+ train_metrics.lr = optimizer.param_groups[0]["lr"]
114
+
115
+ return train_metrics
116
+
117
+
118
+ @parser.wrap()
119
+ def train(cfg: TrainPipelineConfig):
120
+ cfg.validate()
121
+
122
+ accelerator_kwargs = {
123
+ "step_scheduler_with_optimizer": False,
124
+ "split_batches": False, # split_batches == True is not working anyways
125
+ "kwargs_handlers": [DistributedDataParallelKwargs(find_unused_parameters=True)],
126
+ }
127
+ if cfg.wandb.enable:
128
+ accelerator_kwargs["log_with"] = "wandb"
129
+ if cfg.gradient_accumulation_steps > 1:
130
+ accelerator_kwargs["gradient_accumulation_steps"] = cfg.gradient_accumulation_steps
131
+
132
+ accelerator = accelerate.Accelerator(**accelerator_kwargs)
133
+ init_logging(accelerator, level=logging.DEBUG if cfg.debug else logging.INFO)
134
+ # Register the accelerator globally for use in other modules (e.g., to detect the current rank).
135
+ set_proc_accelerator(accelerator)
136
+
137
+ logging.info(pformat(cfg.to_dict()))
138
+
139
+ if accelerator.is_main_process:
140
+ accelerator_config = encode_accelerator_state_dict(accelerator.state.__dict__)
141
+ logging.info(pformat(accelerator_config))
142
+
143
+ # Ensure `gradient_accumulation_steps` is consistent between TrainPipelineConfig and DeepSpeedConfig
144
+ if accelerator.distributed_type == accelerate.DistributedType.DEEPSPEED:
145
+ deepspeed_config, deepspeed_key = accelerator.deepspeed_plugin.hf_ds_config.find_config_node(
146
+ "gradient_accumulation_steps"
147
+ )
148
+ ds_grad_acc_steps = deepspeed_config.get(deepspeed_key, 1)
149
+ if ds_grad_acc_steps != cfg.gradient_accumulation_steps:
150
+ raise ValueError(
151
+ "The `gradient_accumulation_steps` in TrainPipelineConfig does not match the value "
152
+ f"specified in DeepSpeedConfig {cfg.gradient_accumulation_steps} != {ds_grad_acc_steps}. " # nosec B608
153
+ )
154
+
155
+ if cfg.wandb.enable:
156
+ step = load_training_step(cfg.checkpoint_path) if cfg.resume else None
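+ # When resuming, recover the last completed training step so the wandb run can continue from it.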
157
+ slurm_dict = {k: v for k, v in os.environ.items() if k.startswith("SLURM_")}
158
+ accelerator.init_trackers(
159
+ cfg.wandb.project,
160
+ config={**cfg.to_dict(), "accelerator": accelerator_config, "slurm": slurm_dict},
161
+ init_kwargs={"wandb": cfg.wandb.to_wandb_kwargs(step=step)},
162
+ )
163
+ tracker = accelerator.get_tracker("wandb", unwrap=True)
164
+ cfg.wandb.run_id = tracker.id
165
+ logging.info(f"tracker initialized with wandb job id: {tracker.id}")
166
+
167
+ if cfg.seed is not None:
168
+ set_seed(cfg.seed, accelerator=accelerator)
169
+
170
+ # Enable anomaly detection for debugging NaN/Inf values
171
+ # (warning: large computational overhead)
172
+ torch.autograd.set_detect_anomaly(cfg.trace_nans)
173
+ if cfg.trace_nans:
174
+ logging.warning("Anomaly detection is enabled. This may significantly slow down training.")
175
+ else:
176
+ logging.info("Anomaly detection is disabled.")
177
+
178
+ logging.info("Creating dataset")
179
+ dataset = make_dataset_mixture(cfg)
180
+
181
+ # Create the environments used to evaluate checkpoints during training on simulation data.
182
+ # For real-world data there is no need to create an environment, since evaluations are done outside train.py.
183
+ eval_envs = None
184
+ if cfg.eval_freq > 0 and cfg.env is not None:
185
+ logging.info("Creating env")
186
+ eval_envs = make_envs(
187
+ cfg.env, cfg, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs
188
+ )
189
+
190
+ logging.info("Creating policy")
191
+ policy = make_policy(cfg=cfg.policy, ds_meta=dataset.meta)
192
+ policy.to(torch.bfloat16)
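+ # Note: the policy is cast to bfloat16 before the optimizer is created, so the optimizer is built over the bf16 parameters.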
193
+ logging.info("Creating optimizer and scheduler")
194
+ optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
195
+
196
+ step = 0 # number of policy updates (forward + backward + optim)
197
+
198
+ if accelerator.is_main_process:
199
+ num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
200
+ num_total_params = sum(p.numel() for p in policy.parameters())
201
+ logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
202
+ logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
203
+ logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
204
+ logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
205
+
206
+ dataloader = dataset.get_dataloader()
207
+ policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
208
+ policy, optimizer, dataloader, lr_scheduler
209
+ )
210
+ dl_iter = cycle(dataloader)
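+ # `cycle` turns the dataloader into an endlessly repeating iterator: training length is measured in optimizer steps (cfg.steps), not epochs.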
211
+
212
+ # Register the LR scheduler for checkpointing
213
+ accelerator.register_for_checkpointing(lr_scheduler)
214
+
215
+ if cfg.resume:
216
+ # load accelerator state
217
+ # This will load the model, optimizer, and lr_scheduler state
218
+ accelerator.load_state(cfg.checkpoint_path)
219
+
220
+ # all processes should load the step & rng states
221
+ step = load_training_state(cfg.checkpoint_path)
222
+ logging.info(f"Resuming training from checkpoint {cfg.checkpoint_path}")
223
+
224
+ policy.train()
225
+
226
+ # setup metrics tracker to average metrics over the logging interval
227
+ train_metrics = {
228
+ "loss": AverageMeter("total_loss", ":.3f"),
229
+ "mse_loss": AverageMeter("mse_loss", ":.3f"),
230
+ "ce_loss": AverageMeter("ce_loss", ":.3f"),
231
+ "l1_loss": AverageMeter("l1_loss", ":.3f"),
232
+ "accuracy": AverageMeter("accuracy", ":.3f"),
233
+ "lr": AverageMeter("lr", ":0.1e"),
234
+ "grad_norm": AverageMeter("grad_norm", ":.3f"),
235
+ }
236
+ train_tracker = MetricsTracker(
237
+ cfg.batch_size * accelerator.num_processes, # split_batches is not working
238
+ train_metrics,
239
+ initial_step=step,
240
+ )
241
+
242
+ if accelerator.is_main_process:
243
+ logging.info("Start offline training on a fixed dataset")
244
+
245
+ for _ in range(step, cfg.steps):
246
+ for _ in range(cfg.gradient_accumulation_steps):
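+ # Each outer step runs `gradient_accumulation_steps` micro-batches; when enabled, `accelerator.accumulate`
+ # defers gradient synchronization to the last micro-batch of the group.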
247
+ with accelerator.accumulate(policy) if cfg.gradient_accumulation_steps > 1 else nullcontext():
248
+ logging.debug(f"{step=}, {accelerator.sync_gradients=}")
249
+ batch = next(dl_iter)
250
+
251
+ train_tracker = update_policy(
252
+ cfg,
253
+ train_tracker,
254
+ policy,
255
+ batch,
256
+ optimizer,
257
+ cfg.optimizer.grad_clip_norm,
258
+ accelerator=accelerator,
259
+ lr_scheduler=lr_scheduler,
260
+ )
261
+
262
+ # Note: eval and checkpoint happen *after* the `step`th training update has completed, so we
263
+ # increment `step` here.
264
+ step += 1
265
+ train_tracker.step()
266
+ is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
267
+ is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and cfg.save_checkpoint
268
+ is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
269
+
270
+ # Only `train_tracker` on the main process keeps useful statistics,
271
+ # because we guarded it with `if accelerator.is_main_process` in the `update_policy` function.
272
+ if is_log_step and accelerator.is_main_process:
273
+ logging.info(train_tracker)
274
+ log_dict = train_tracker.to_dict(use_avg=True)
275
+ accelerator.log({"Training Loss": log_dict["loss"]}, step=step)
276
+ accelerator.log({"MSE Loss": log_dict["mse_loss"]}, step=step)
277
+ accelerator.log({"CE Loss": log_dict["ce_loss"]}, step=step)
278
+ accelerator.log({"L1 Loss": log_dict["l1_loss"]}, step=step)
279
+ accelerator.log({"Accuracy": log_dict["accuracy"]}, step=step)
280
+ accelerator.log({"Learning Rate": log_dict["lr"]}, step=step)
281
+ accelerator.log({"Grad Norm": log_dict["grad_norm"]}, step=step)
282
+ accelerator.log({"Num Samples": log_dict["samples"]}, step=step)
283
+ train_tracker.reset_averages()
284
+
285
+ if is_saving_step:
286
+ # TODO: investigate whether this barrier is needed
287
+ accelerator.wait_for_everyone()
288
+ checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
289
+
290
+ # save the accelerator state
291
+ # This will save the model, optimizer, and lr_scheduler state
292
+ accelerator.save_state(checkpoint_dir)
293
+
294
+ # save auxiliary objects such as configs, training step, and rng state
295
+ if accelerator.is_main_process:
296
+ logging.info(f"Checkpoint policy after step {step}")
297
+ cfg.policy.pretrained_path = checkpoint_dir
298
+ save_checkpoint(checkpoint_dir, step, cfg)
299
+ if cfg.last_checkpoint_only:
300
+ prune_old_checkpoints(checkpoint_dir)
301
+
302
+ # This barrier is probably necessary to ensure
303
+ # other processes wait for the main process to finish saving
304
+ accelerator.wait_for_everyone()
305
+
306
+ if is_eval_step and eval_envs:
307
+ step_id = get_step_identifier(step, cfg.steps)
308
+ logging.info(f"Eval policy at step {step}")
309
+ with (
310
+ torch.no_grad(),
311
+ torch.autocast(device_type=accelerator.device.type) if cfg.policy.use_amp else nullcontext(),
312
+ ):
313
+ eval_info = eval_policy_all(
314
+ eval_envs,
315
+ policy,
316
+ cfg.eval.n_episodes,
317
+ cfg,
318
+ videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
319
+ max_episodes_rendered=cfg.eval.max_episodes_rendered,
320
+ start_seed=cfg.seed,
321
+ max_parallel_tasks=cfg.env.max_parallel_tasks,
322
+ )
323
+
324
+ eval_info = gather_object([eval_info]) # gather across all accelerator processes
325
+ if accelerator.is_main_process:
326
+ eval_info = consolidate_eval_info(eval_info)
327
+ # overall metrics (suite-agnostic)
328
+ aggregated = eval_info["overall"]
329
+
330
+ # optional: per-suite logging
331
+ for suite, suite_info in eval_info.items():
332
+ logging.info("Suite %s aggregated: %s", suite, suite_info)
333
+
334
+ # meters/tracker
335
+ eval_metrics = {
336
+ "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
337
+ "pc_success": AverageMeter("success", ":.1f"),
338
+ "eval_per_gpu_s": AverageMeter("eval_per_gpu_s", ":.3f"),
339
+ }
340
+ eval_tracker = MetricsTracker(
341
+ cfg.batch_size,
342
+ eval_metrics,
343
+ initial_step=step,
344
+ )
345
+ eval_tracker.eval_per_gpu_s = aggregated.get("eval_per_gpu_s", float("nan"))
346
+ eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan"))
347
+ eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
348
+ logging.info(eval_tracker)
349
+ eval_dict = eval_tracker.to_dict(use_avg=True)
350
+ accelerator.log({"Success Rate": eval_dict["pc_success"]}, step=step)
351
+ accelerator.log({"Evaluation Time": eval_dict["eval_per_gpu_s"]}, step=step)
352
+ for group, v in eval_info["per_group"].items():
353
+ accelerator.log({f"Success/{group}": v["pc_success"]}, step=step)
354
+
355
+ # Save eval_info to the same directory as videos
356
+ videos_dir = cfg.output_dir / "eval" / f"videos_step_{step_id}"
357
+ with open(videos_dir / "eval_info.json", "w") as f:
358
+ json.dump(eval_info, f, indent=2)
359
+
360
+ if is_eval_step:
361
+ # This barrier ensures all processes finish evaluation before the next training step.
362
+ # Some processes might be slower than others
363
+ accelerator.wait_for_everyone()
364
+
365
+ if cfg.eval_freq > 0 and eval_envs:
366
+ close_envs(eval_envs)
367
+
368
+ accelerator.end_training()
369
+ if accelerator.is_main_process:
370
+ logging.info("End of training")
371
+
372
+
373
+ if __name__ == "__main__":
374
+ if not is_launched_with_accelerate():
375
+ raise Exception(
376
+ "This script should be launched with accelerate. Please use `accelerate launch` to run this script."
377
+ )
378
+
379
+ train()
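
The comments in `update_policy` above note that assigning to a `MetricsTracker` attribute calls `AverageMeter.update`, so repeated assignments within a logging interval are averaged rather than overwritten. A minimal sketch of that behavior (an illustrative stand-in, not the actual `opentau.utils.logging_utils` implementation):

```
class AverageMeterSketch:
    """Running average: repeated assignments accumulate instead of overwriting."""

    def __init__(self) -> None:
        self.sum = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.sum += value
        self.count += 1

    @property
    def avg(self) -> float:
        return self.sum / max(self.count, 1)


meter = AverageMeterSketch()
meter.update(1.0)  # e.g. train_metrics.loss = 1
meter.update(2.0)  # e.g. train_metrics.loss = 2
print(meter.avg)   # 1.5 -- the logged value is the average over the logging interval
```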
opentau/scripts/visualize_dataset.py
@@ -0,0 +1,294 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
18
+
19
+ Note: The last frame of the episode doesn't always correspond to a final state.
20
+ That's because our datasets are composed of transitions from state to state, up to
21
+ the penultimate state, which is paired with the final action that leads to the final state.
22
+ However, there might not be a transition from a final state to another state.
23
+
24
+ Note: This script aims to visualize the data used to train the neural networks.
25
+ What you see is what you get. When visualizing image modalities, expect to observe
26
+ lossy compression artifacts, since these images have been decoded from compressed mp4 videos to
27
+ save disk space. The compression factor applied has been tuned so as not to affect the success rate.
28
+
29
+ Examples:
30
+
31
+ - Visualize data stored on a local machine:
32
+ ```
33
+ local$ python src/opentau/scripts/visualize_dataset.py \
34
+ --repo-id lerobot/pusht \
35
+ --episode-index 0
36
+ ```
37
+
38
+ - Visualize data stored on a distant machine with a local viewer:
39
+ ```
40
+ distant$ python src/opentau/scripts/visualize_dataset.py \
41
+ --repo-id lerobot/pusht \
42
+ --episode-index 0 \
43
+ --save 1 \
44
+ --output-dir path/to/directory
45
+
46
+ local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
47
+ local$ rerun lerobot_pusht_episode_0.rrd
48
+ ```
49
+
50
+ - Visualize data stored on a distant machine through streaming:
51
+ (You need to forward the websocket port to the distant machine, with
52
+ `ssh -L 9087:localhost:9087 username@remote-host`)
53
+ ```
54
+ distant$ python src/opentau/scripts/visualize_dataset.py \
55
+ --repo-id lerobot/pusht \
56
+ --episode-index 0 \
57
+ --mode distant \
58
+ --ws-port 9087
59
+
60
+ local$ rerun ws://localhost:9087
61
+ ```
62
+
63
+ """
64
+
65
+ import argparse
66
+ import gc
67
+ import logging
68
+ import time
69
+ from pathlib import Path
70
+ from typing import Iterator
71
+
72
+ import numpy as np
73
+ import rerun as rr
74
+ import torch
75
+ import torch.utils.data
76
+ import tqdm
77
+
78
+ from opentau.datasets.lerobot_dataset import LeRobotDataset
79
+ from opentau.scripts.visualize_dataset_html import create_mock_train_config
80
+
81
+
82
+ class EpisodeSampler(torch.utils.data.Sampler):
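+ # Sampler that yields the global frame indices of a single episode, using the dataset's episode_data_index bounds.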
83
+ def __init__(self, dataset: LeRobotDataset, episode_index: int):
84
+ from_idx = dataset.episode_data_index["from"][episode_index].item()
85
+ to_idx = dataset.episode_data_index["to"][episode_index].item()
86
+ self.frame_ids = range(from_idx, to_idx)
87
+
88
+ def __iter__(self) -> Iterator:
89
+ return iter(self.frame_ids)
90
+
91
+ def __len__(self) -> int:
92
+ return len(self.frame_ids)
93
+
94
+
95
+ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
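+ # Convert a float32 CHW image with values in [0, 1] into a uint8 HWC array suitable for rr.Image.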
96
+ assert chw_float32_torch.dtype == torch.float32
97
+ assert chw_float32_torch.ndim == 3
98
+ c, h, w = chw_float32_torch.shape
99
+ assert c < h and c < w, f"expected channel-first images, but got shape {chw_float32_torch.shape}"
100
+ hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
101
+ return hwc_uint8_numpy
102
+
103
+
104
+ def visualize_dataset(
105
+ dataset: LeRobotDataset,
106
+ episode_index: int,
107
+ batch_size: int = 32,
108
+ num_workers: int = 0,
109
+ mode: str = "local",
110
+ web_port: int = 9090,
111
+ ws_port: int = 9087,
112
+ save: bool = False,
113
+ output_dir: Path | None = None,
114
+ ) -> Path | None:
115
+ if save:
116
+ assert output_dir is not None, (
117
+ "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
118
+ )
119
+
120
+ repo_id = dataset.repo_id
121
+
122
+ logging.info("Loading dataloader")
123
+ episode_sampler = EpisodeSampler(dataset, episode_index)
124
+ dataloader = torch.utils.data.DataLoader(
125
+ dataset,
126
+ num_workers=num_workers,
127
+ batch_size=batch_size,
128
+ sampler=episode_sampler,
129
+ )
130
+
131
+ logging.info("Starting Rerun")
132
+
133
+ if mode not in ["local", "distant"]:
134
+ raise ValueError(mode)
135
+
136
+ spawn_local_viewer = mode == "local" and not save
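+ # Spawn the local Rerun viewer only when viewing locally and not just saving to a .rrd file.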
137
+ rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
138
+
139
+ # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
140
+ # when iterating on a dataloader with `num_workers` > 0
141
+ # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
142
+ gc.collect()
143
+
144
+ if mode == "distant":
145
+ rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
146
+
147
+ logging.info("Logging to Rerun")
148
+
149
+ for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
150
+ # iterate over the batch
151
+ for i in range(len(batch["index"])):
152
+ rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
153
+ rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
154
+
155
+ # display each camera image
156
+ for key in dataset.meta.camera_keys:
157
+ # TODO(rcadene): add `.compress()`? is it lossless?
158
+ rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
159
+
160
+ # display each dimension of action space (e.g. actuators command)
161
+ if "action" in batch:
162
+ for dim_idx, val in enumerate(batch["action"][i]):
163
+ rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
164
+
165
+ # display each dimension of observed state space (e.g. agent position in joint space)
166
+ if "observation.state" in batch:
167
+ for dim_idx, val in enumerate(batch["observation.state"][i]):
168
+ rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
169
+
170
+ if "next.done" in batch:
171
+ rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
172
+
173
+ if "next.reward" in batch:
174
+ rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
175
+
176
+ if "next.success" in batch:
177
+ rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
178
+
179
+ if mode == "local" and save:
180
+ # save .rrd locally
181
+ output_dir = Path(output_dir)
182
+ output_dir.mkdir(parents=True, exist_ok=True)
183
+ repo_id_str = repo_id.replace("/", "_")
184
+ rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
185
+ rr.save(rrd_path)
186
+ return rrd_path
187
+
188
+ elif mode == "distant":
189
+ # stop the process from exiting since it is serving the websocket connection
190
+ try:
191
+ while True:
192
+ time.sleep(1)
193
+ except KeyboardInterrupt:
194
+ print("Ctrl-C received. Exiting.")
195
+
196
+
197
+ def main():
198
+ parser = argparse.ArgumentParser()
199
+
200
+ parser.add_argument(
201
+ "--repo-id",
202
+ type=str,
203
+ required=True,
204
+ help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
205
+ )
206
+ parser.add_argument(
207
+ "--episode-index",
208
+ type=int,
209
+ required=True,
210
+ help="Episode to visualize.",
211
+ )
212
+ parser.add_argument(
213
+ "--root",
214
+ type=Path,
215
+ default=None,
216
+ help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
217
+ )
218
+ parser.add_argument(
219
+ "--output-dir",
220
+ type=Path,
221
+ default=None,
222
+ help="Directory path to write a .rrd file when `--save 1` is set.",
223
+ )
224
+ parser.add_argument(
225
+ "--batch-size",
226
+ type=int,
227
+ default=32,
228
+ help="Batch size loaded by DataLoader.",
229
+ )
230
+ parser.add_argument(
231
+ "--num-workers",
232
+ type=int,
233
+ default=4,
234
+ help="Number of processes of Dataloader for loading the data.",
235
+ )
236
+ parser.add_argument(
237
+ "--mode",
238
+ type=str,
239
+ default="local",
240
+ help=(
241
+ "Mode of viewing between 'local' or 'distant'. "
242
+ "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
243
+ "'distant' creates a server on the distant machine where the data is stored. "
244
+ "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
245
+ ),
246
+ )
247
+ parser.add_argument(
248
+ "--web-port",
249
+ type=int,
250
+ default=9090,
251
+ help="Web port for rerun.io when `--mode distant` is set.",
252
+ )
253
+ parser.add_argument(
254
+ "--ws-port",
255
+ type=int,
256
+ default=9087,
257
+ help="Web socket port for rerun.io when `--mode distant` is set.",
258
+ )
259
+ parser.add_argument(
260
+ "--save",
261
+ type=int,
262
+ default=0,
263
+ help=(
264
+ "Save a .rrd file in the directory provided by `--output-dir`. "
265
+ "It also deactivates the spawning of a viewer. "
266
+ "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
267
+ ),
268
+ )
269
+
270
+ parser.add_argument(
271
+ "--tolerance-s",
272
+ type=float,
273
+ default=1e-4,
274
+ help=(
275
+ "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
276
+ "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
277
+ "If not given, defaults to 1e-4."
278
+ ),
279
+ )
280
+
281
+ args = parser.parse_args()
282
+ kwargs = vars(args)
283
+ repo_id = kwargs.pop("repo_id")
284
+ root = kwargs.pop("root")
285
+ tolerance_s = kwargs.pop("tolerance_s")
286
+
287
+ logging.info("Loading dataset")
288
+ dataset = LeRobotDataset(create_mock_train_config(), repo_id, root=root, tolerance_s=tolerance_s)
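+ # `repo_id`, `root` and `tolerance_s` were popped above; the pops mutate `args.__dict__`, so
+ # `**vars(args)` below forwards only the arguments that `visualize_dataset()` accepts.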
289
+
290
+ visualize_dataset(dataset, **vars(args))
291
+
292
+
293
+ if __name__ == "__main__":
294
+ main()