opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/scripts/train.py
ADDED
@@ -0,0 +1,379 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import os
+from contextlib import nullcontext
+from pprint import pformat
+from typing import Any
+
+import accelerate
+import torch
+from accelerate.optimizer import AcceleratedOptimizer
+from accelerate.scheduler import AcceleratedScheduler
+from accelerate.utils import DistributedDataParallelKwargs, gather_object
+from termcolor import colored
+
+from opentau.configs import parser
+from opentau.configs.train import TrainPipelineConfig
+from opentau.datasets.factory import make_dataset_mixture
+from opentau.datasets.utils import cycle
+from opentau.envs.factory import make_envs
+from opentau.envs.utils import close_envs
+from opentau.optim.factory import make_optimizer_and_scheduler
+from opentau.policies.factory import make_policy
+from opentau.policies.pretrained import PreTrainedPolicy
+from opentau.scripts.eval import consolidate_eval_info, eval_policy_all
+from opentau.utils.accelerate_utils import set_proc_accelerator
+from opentau.utils.logging_utils import AverageMeter, MetricsTracker
+from opentau.utils.random_utils import set_seed
+from opentau.utils.train_utils import (
+    get_step_checkpoint_dir,
+    get_step_identifier,
+    load_training_state,
+    load_training_step,
+    prune_old_checkpoints,
+    save_checkpoint,
+)
+from opentau.utils.utils import (
+    encode_accelerator_state_dict,
+    format_big_number,
+    init_logging,
+    is_launched_with_accelerate,
+)
+
+
+def update_policy(
+    train_config: TrainPipelineConfig,
+    train_metrics: MetricsTracker,
+    policy: PreTrainedPolicy,
+    batch: Any,
+    optimizer: AcceleratedOptimizer,
+    grad_clip_norm: float,
+    accelerator: accelerate.Accelerator,
+    lr_scheduler: AcceleratedScheduler | None = None,
+) -> tuple[MetricsTracker, dict]:
+    policy.train()
+    losses = policy.forward(batch)
+    loss = (
+        train_config.loss_weighting["MSE"] * losses["MSE"] + train_config.loss_weighting["CE"] * losses["CE"]
+    )
+
+    # accelerator.backward(loss)
+    # accelerator.unscale_gradients(optimizer=optimizer)
+
+    # if accelerator.sync_gradients:
+    #     grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
+    #     if accelerator.is_main_process:
+    #         train_metrics.grad_norm = grad_norm
+
+    # optimizer.step()
+    # optimizer.zero_grad()
+
+    # Step through pytorch scheduler at every batch instead of epoch
+    if lr_scheduler is not None:
+        lr_scheduler.step()
+
+    # This calls `torch.distributed.all_gather_into_tensor` under the hood, which is not so efficient.
+    # We don't actually want to broadcast the gathered tensors to all processes, but only to the main process.
+    # Nonetheless, we still do this for correctness, safety, and simplicity.
+    _first_loss_tensor = next(lt for lt in losses.values() if isinstance(lt, torch.Tensor))
+    zero = torch.tensor(0.0, device=_first_loss_tensor.device, dtype=_first_loss_tensor.dtype)
+    loss = accelerator.gather_for_metrics(loss).mean().item()
+    mse_loss = accelerator.gather_for_metrics(losses["MSE"]).to(dtype=torch.float32).mean().item()
+    ce_loss = accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item()
+    l1_loss = accelerator.gather_for_metrics(losses.get("L1", zero)).to(dtype=torch.float32).mean().item()
+    accuracy = (
+        accelerator.gather_for_metrics(losses.get("Accuracy", zero)).to(dtype=torch.float32).mean().item()
+    )
+    # This actually calls `.update` method of the `AverageMeter` class. This operation is not idempotent.
+    # See MetricsTracker.__setattr__ for more details.
+    # In other words, setting `train_metrics.loss = 1` and `train_metrics.loss = 2` consecutively results in
+    # an average of 1.5 when formatted as a string, not just 2.
+    if accelerator.is_main_process:
+        train_metrics.loss = loss
+        train_metrics.mse_loss = mse_loss
+        train_metrics.ce_loss = ce_loss
+        train_metrics.l1_loss = l1_loss
+        train_metrics.accuracy = accuracy
+        train_metrics.lr = optimizer.param_groups[0]["lr"]
+
+    return train_metrics
+
+
+@parser.wrap()
+def train(cfg: TrainPipelineConfig):
+    cfg.validate()
+
+    accelerator_kwargs = {
+        "step_scheduler_with_optimizer": False,
+        "split_batches": False,  # split_batches == True is not working anyways
+        "kwargs_handlers": [DistributedDataParallelKwargs(find_unused_parameters=True)],
+    }
+    if cfg.wandb.enable:
+        accelerator_kwargs["log_with"] = "wandb"
+    if cfg.gradient_accumulation_steps > 1:
+        accelerator_kwargs["gradient_accumulation_steps"] = cfg.gradient_accumulation_steps
+
+    accelerator = accelerate.Accelerator(**accelerator_kwargs)
+    init_logging(accelerator, level=logging.DEBUG if cfg.debug else logging.INFO)
+    # Register accelerator globally for use in other modules, (e.g., detect current rank, etc.)
+    set_proc_accelerator(accelerator)
+
+    logging.info(pformat(cfg.to_dict()))
+
+    if accelerator.is_main_process:
+        accelerator_config = encode_accelerator_state_dict(accelerator.state.__dict__)
+        logging.info(pformat(accelerator_config))
+
+    # Ensure `gradient_accumulation_steps` is consistent between TrainPipelineConfig and DeepSpeedConfig
+    if accelerator.distributed_type == accelerate.DistributedType.DEEPSPEED:
+        deepspeed_config, deepspeed_key = accelerator.deepspeed_plugin.hf_ds_config.find_config_node(
+            "gradient_accumulation_steps"
+        )
+        ds_grad_acc_steps = deepspeed_config.get(deepspeed_key, 1)
+        if ds_grad_acc_steps != cfg.gradient_accumulation_steps:
+            raise ValueError(
+                "The `gradient_accumulation_steps` in TrainPipelineConfig does not match the value "
+                f"specified in DeepSpeedConfig {cfg.gradient_accumulation_steps} != {ds_grad_acc_steps}. "  # nosec B608
+            )
+
+    if cfg.wandb.enable:
+        step = load_training_step(cfg.checkpoint_path) if cfg.resume else None
+        slurm_dict = {k: v for k, v in os.environ.items() if k.startswith("SLURM_")}
+        accelerator.init_trackers(
+            cfg.wandb.project,
+            config={**cfg.to_dict(), "accelerator": accelerator_config, "slurm": slurm_dict},
+            init_kwargs={"wandb": cfg.wandb.to_wandb_kwargs(step=step)},
+        )
+        tracker = accelerator.get_tracker("wandb", unwrap=True)
+        cfg.wandb.run_id = tracker.id
+        logging.info(f"tracker initialized with wandb job id: {tracker.id}")
+
+    if cfg.seed is not None:
+        set_seed(cfg.seed, accelerator=accelerator)
+
+    # Enable anomaly detection for debugging NaN/Inf values
+    # (warning: large computational overhead)
+    torch.autograd.set_detect_anomaly(cfg.trace_nans)
+    if cfg.trace_nans:
+        logging.warning("Anomaly detection is enabled. This may significantly slow down training.")
+    else:
+        logging.info("Anomaly detection is disabled.")
+
+    logging.info("Creating dataset")
+    dataset = make_dataset_mixture(cfg)
+
+    # Create environment used for evaluating checkpoints during training on simulation data.
+    # On real-world data, no need to create an environment as evaluations are done outside train.py,
+    eval_envs = None
+    if cfg.eval_freq > 0 and cfg.env is not None:
+        logging.info("Creating env")
+        eval_envs = make_envs(
+            cfg.env, cfg, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs
+        )
+
+    logging.info("Creating policy")
+    policy = make_policy(cfg=cfg.policy, ds_meta=dataset.meta)
+    policy.to(torch.bfloat16)
+    logging.info("Creating optimizer and scheduler")
+    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+
+    step = 0  # number of policy updates (forward + backward + optim)
+
+    if accelerator.is_main_process:
+        num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+        num_total_params = sum(p.numel() for p in policy.parameters())
+        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
+        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
+        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
+    dataloader = dataset.get_dataloader()
+    policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+        policy, optimizer, dataloader, lr_scheduler
+    )
+    dl_iter = cycle(dataloader)
+
+    # Register the LR scheduler for checkpointing
+    accelerator.register_for_checkpointing(lr_scheduler)
+
+    if cfg.resume:
+        # load accelerator state
+        # This will load the model, optimizer, and lr_scheduler state
+        accelerator.load_state(cfg.checkpoint_path)
+
+        # all processes should load the step & rng states
+        step = load_training_state(cfg.checkpoint_path)
+        logging.info(f"Resuming training from checkpoint {cfg.checkpoint_path}")
+
+    policy.train()
+
+    # setup metrics tracker to average metrics over the logging interval
+    train_metrics = {
+        "loss": AverageMeter("total_loss", ":.3f"),
+        "mse_loss": AverageMeter("mse_loss", ":.3f"),
+        "ce_loss": AverageMeter("ce_loss", ":.3f"),
+        "l1_loss": AverageMeter("l1_loss", ":.3f"),
+        "accuracy": AverageMeter("accuracy", ":.3f"),
+        "lr": AverageMeter("lr", ":0.1e"),
+        "grad_norm": AverageMeter("grad_norm", ":.3f"),
+    }
+    train_tracker = MetricsTracker(
+        cfg.batch_size * accelerator.num_processes,  # split_batches are not working
+        train_metrics,
+        initial_step=step,
+    )
+
+    if accelerator.is_main_process:
+        logging.info("Start offline training on a fixed dataset")
+
+    for _ in range(step, cfg.steps):
+        for _ in range(cfg.gradient_accumulation_steps):
+            with accelerator.accumulate(policy) if cfg.gradient_accumulation_steps > 1 else nullcontext():
+                logging.debug(f"{step=}, {accelerator.sync_gradients=}")
+                batch = next(dl_iter)
+
+                train_tracker = update_policy(
+                    cfg,
+                    train_tracker,
+                    policy,
+                    batch,
+                    optimizer,
+                    cfg.optimizer.grad_clip_norm,
+                    accelerator=accelerator,
+                    lr_scheduler=lr_scheduler,
+                )
+
+        # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
+        # increment `step` here.
+        step += 1
+        train_tracker.step()
+        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
+        is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and cfg.save_checkpoint
+        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
+
+        # Only `train_tracker` on the main process keeps useful statistics,
+        # because we guarded it with if accelerator.is_main_process in the `update_policy` function.
+        if is_log_step and accelerator.is_main_process:
+            logging.info(train_tracker)
+            log_dict = train_tracker.to_dict(use_avg=True)
+            accelerator.log({"Training Loss": log_dict["loss"]}, step=step)
+            accelerator.log({"MSE Loss": log_dict["mse_loss"]}, step=step)
+            accelerator.log({"CE Loss": log_dict["ce_loss"]}, step=step)
+            accelerator.log({"L1 Loss": log_dict["l1_loss"]}, step=step)
+            accelerator.log({"Accuracy": log_dict["accuracy"]}, step=step)
+            accelerator.log({"Learning Rate": log_dict["lr"]}, step=step)
+            accelerator.log({"Grad Norm": log_dict["grad_norm"]}, step=step)
+            accelerator.log({"Num Samples": log_dict["samples"]}, step=step)
+            train_tracker.reset_averages()
+
+        if is_saving_step:
+            # TODO: investigate whether this barrier is needed
+            accelerator.wait_for_everyone()
+            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
+
+            # save the accelerator state
+            # This will save the model, optimizer, and lr_scheduler state
+            accelerator.save_state(checkpoint_dir)
+
+            # save axillary objects such as configs, training step, and rng state
+            if accelerator.is_main_process:
+                logging.info(f"Checkpoint policy after step {step}")
+                cfg.policy.pretrained_path = checkpoint_dir
+                save_checkpoint(checkpoint_dir, step, cfg)
+                if cfg.last_checkpoint_only:
+                    prune_old_checkpoints(checkpoint_dir)
+
+            # This barrier is probably necessary to ensure
+            # other processes wait for the main process to finish saving
+            accelerator.wait_for_everyone()
+
+        if is_eval_step and eval_envs:
+            step_id = get_step_identifier(step, cfg.steps)
+            logging.info(f"Eval policy at step {step}")
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=accelerator.device.type) if cfg.policy.use_amp else nullcontext(),
+            ):
+                eval_info = eval_policy_all(
+                    eval_envs,
+                    policy,
+                    cfg.eval.n_episodes,
+                    cfg,
+                    videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
+                    max_episodes_rendered=cfg.eval.max_episodes_rendered,
+                    start_seed=cfg.seed,
+                    max_parallel_tasks=cfg.env.max_parallel_tasks,
+                )
+
+            eval_info = gather_object([eval_info])  # gather across all accelerator processes
+            if accelerator.is_main_process:
+                eval_info = consolidate_eval_info(eval_info)
+                # overall metrics (suite-agnostic)
+                aggregated = eval_info["overall"]
+
+                # optional: per-suite logging
+                for suite, suite_info in eval_info.items():
+                    logging.info("Suite %s aggregated: %s", suite, suite_info)
+
+                # meters/tracker
+                eval_metrics = {
+                    "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
+                    "pc_success": AverageMeter("success", ":.1f"),
+                    "eval_per_gpu_s": AverageMeter("eval_per_gpu_s", ":.3f"),
+                }
+                eval_tracker = MetricsTracker(
+                    cfg.batch_size,
+                    eval_metrics,
+                    initial_step=step,
+                )
+                eval_tracker.eval_per_gpu_s = aggregated.get("eval_per_gpu_s", float("nan"))
+                eval_tracker.avg_sum_reward = aggregated.get("avg_sum_reward", float("nan"))
+                eval_tracker.pc_success = aggregated.get("pc_success", float("nan"))
+                logging.info(eval_tracker)
+                eval_dict = eval_tracker.to_dict(use_avg=True)
+                accelerator.log({"Success Rate": eval_dict["pc_success"]}, step=step)
+                accelerator.log({"Evaluation Time": eval_dict["eval_per_gpu_s"]}, step=step)
+                for group, v in eval_info["per_group"].items():
+                    accelerator.log({f"Success/{group}": v["pc_success"]}, step=step)
+
+                # Save eval_info to the same directory as videos
+                videos_dir = cfg.output_dir / "eval" / f"videos_step_{step_id}"
+                with open(videos_dir / "eval_info.json", "w") as f:
+                    json.dump(eval_info, f, indent=2)
+
+        if is_eval_step:
+            # This barrier is to ensure all processes finishes evaluation before the next training step
+            # Some processes might be slower than others
+            accelerator.wait_for_everyone()
+
+    if cfg.eval_freq > 0 and eval_envs:
+        close_envs(eval_envs)
+
+    accelerator.end_training()
+    if accelerator.is_main_process:
+        logging.info("End of training")
+
+
+if __name__ == "__main__":
+    if not is_launched_with_accelerate():
+        raise Exception(
+            "This script should be launched with accelerate. Please use `accelerate launch` to run this script."
+        )
+
+    train()
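
The comments in `update_policy` above point at a detail worth spelling out: assigning to an attribute of the `MetricsTracker` (e.g. `train_metrics.loss = loss`) is routed through `AverageMeter.update`, so repeated assignments within a logging interval accumulate a running average rather than overwriting the previous value. The real implementations live in `opentau/utils/logging_utils.py` and are not part of this diff; the sketch below is a simplified, hypothetical stand-in meant only to illustrate that non-idempotent `__setattr__` behaviour.

```
import dataclasses


@dataclasses.dataclass
class AverageMeter:
    # Hypothetical stand-in for opentau.utils.logging_utils.AverageMeter (not the real one).
    name: str
    fmt: str = ":.3f"
    total: float = 0.0
    count: int = 0

    def update(self, value: float, n: int = 1) -> None:
        # Accumulate a running sum; the average is recomputed on read.
        self.total += value * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.total / max(self.count, 1)


class MetricsTracker:
    # Hypothetical stand-in: attribute assignment forwards to the named meter's update().
    def __init__(self, meters: dict):
        super().__setattr__("meters", meters)

    def __setattr__(self, key, value):
        self.meters[key].update(value)

    def __getattr__(self, key):
        return self.meters[key].avg


tracker = MetricsTracker({"loss": AverageMeter("total_loss")})
tracker.loss = 1.0
tracker.loss = 2.0
print(tracker.loss)  # 1.5 -- consecutive assignments average rather than overwrite
```

Under that reading, the call to `train_tracker.reset_averages()` after each logging step makes sense: without it, the meters would keep averaging over the entire run instead of over one logging interval.
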
opentau/scripts/visualize_dataset.py
ADDED
@@ -0,0 +1,294 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+
+Note: The last frame of the episode doesn't always correspond to a final state.
+That's because our datasets are composed of transition from state to state up to
+the antepenultimate state associated to the ultimate action to arrive in the final state.
+However, there might not be a transition from a final state to another state.
+
+Note: This script aims to visualize the data used to train the neural networks.
+~What you see is what you get~. When visualizing image modality, it is often expected to observe
+lossy compression artifacts since these images have been decoded from compressed mp4 videos to
+save disk space. The compression factor applied has been tuned to not affect success rate.
+
+Examples:
+
+- Visualize data stored on a local machine:
+```
+local$ python src/opentau/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0
+```
+
+- Visualize data stored on a distant machine with a local viewer:
+```
+distant$ python src/opentau/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --save 1 \
+    --output-dir path/to/directory
+
+local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
+local$ rerun lerobot_pusht_episode_0.rrd
+```
+
+- Visualize data stored on a distant machine through streaming:
+(You need to forward the websocket port to the distant machine, with
+`ssh -L 9087:localhost:9087 username@remote-host`)
+```
+distant$ python src/opentau/scripts/visualize_dataset.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --mode distant \
+    --ws-port 9087
+
+local$ rerun ws://localhost:9087
+```
+
+"""
+
+import argparse
+import gc
+import logging
+import time
+from pathlib import Path
+from typing import Iterator
+
+import numpy as np
+import rerun as rr
+import torch
+import torch.utils.data
+import tqdm
+
+from opentau.datasets.lerobot_dataset import LeRobotDataset
+from opentau.scripts.visualize_dataset_html import create_mock_train_config
+
+
+class EpisodeSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset: LeRobotDataset, episode_index: int):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self) -> Iterator:
+        return iter(self.frame_ids)
+
+    def __len__(self) -> int:
+        return len(self.frame_ids)
+
+
+def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
+    assert chw_float32_torch.dtype == torch.float32
+    assert chw_float32_torch.ndim == 3
+    c, h, w = chw_float32_torch.shape
+    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
+    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    return hwc_uint8_numpy
+
+
+def visualize_dataset(
+    dataset: LeRobotDataset,
+    episode_index: int,
+    batch_size: int = 32,
+    num_workers: int = 0,
+    mode: str = "local",
+    web_port: int = 9090,
+    ws_port: int = 9087,
+    save: bool = False,
+    output_dir: Path | None = None,
+) -> Path | None:
+    if save:
+        assert output_dir is not None, (
+            "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+        )
+
+    repo_id = dataset.repo_id
+
+    logging.info("Loading dataloader")
+    episode_sampler = EpisodeSampler(dataset, episode_index)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        batch_size=batch_size,
+        sampler=episode_sampler,
+    )
+
+    logging.info("Starting Rerun")
+
+    if mode not in ["local", "distant"]:
+        raise ValueError(mode)
+
+    spawn_local_viewer = mode == "local" and not save
+    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
+
+    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
+    # when iterating on a dataloader with `num_workers` > 0
+    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
+    gc.collect()
+
+    if mode == "distant":
+        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+
+    logging.info("Logging to Rerun")
+
+    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        # iterate over the batch
+        for i in range(len(batch["index"])):
+            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
+            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+
+            # display each camera image
+            for key in dataset.meta.camera_keys:
+                # TODO(rcadene): add `.compress()`? is it lossless?
+                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
+
+            # display each dimension of action space (e.g. actuators command)
+            if "action" in batch:
+                for dim_idx, val in enumerate(batch["action"][i]):
+                    rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
+
+            # display each dimension of observed state space (e.g. agent position in joint space)
+            if "observation.state" in batch:
+                for dim_idx, val in enumerate(batch["observation.state"][i]):
+                    rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
+
+            if "next.done" in batch:
+                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
+
+            if "next.reward" in batch:
+                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+
+            if "next.success" in batch:
+                rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
+
+    if mode == "local" and save:
+        # save .rrd locally
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        repo_id_str = repo_id.replace("/", "_")
+        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
+        rr.save(rrd_path)
+        return rrd_path
+
+    elif mode == "distant":
+        # stop the process from exiting since it is serving the websocket connection
+        try:
+            while True:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            print("Ctrl-C received. Exiting.")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
+    )
+    parser.add_argument(
+        "--episode-index",
+        type=int,
+        required=True,
+        help="Episode to visualize.",
+    )
+    parser.add_argument(
+        "--root",
+        type=Path,
+        default=None,
+        help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        default=None,
+        help="Directory path to write a .rrd file when `--save 1` is set.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size loaded by DataLoader.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of processes of Dataloader for loading the data.",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="local",
+        help=(
+            "Mode of viewing between 'local' or 'distant'. "
+            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
+            "'distant' creates a server on the distant machine where the data is stored. "
+            "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
+        ),
+    )
+    parser.add_argument(
+        "--web-port",
+        type=int,
+        default=9090,
+        help="Web port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--ws-port",
+        type=int,
+        default=9087,
+        help="Web socket port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--save",
+        type=int,
+        default=0,
+        help=(
+            "Save a .rrd file in the directory provided by `--output-dir`. "
+            "It also deactivates the spawning of a viewer. "
+            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
+        ),
+    )
+
+    parser.add_argument(
+        "--tolerance-s",
+        type=float,
+        default=1e-4,
+        help=(
+            "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
+            "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
+            "If not given, defaults to 1e-4."
+        ),
+    )
+
+    args = parser.parse_args()
+    kwargs = vars(args)
+    repo_id = kwargs.pop("repo_id")
+    root = kwargs.pop("root")
+    tolerance_s = kwargs.pop("tolerance_s")
+
+    logging.info("Loading dataset")
+    dataset = LeRobotDataset(create_mock_train_config(), repo_id, root=root, tolerance_s=tolerance_s)
+
+    visualize_dataset(dataset, **vars(args))
+
+
+if __name__ == "__main__":
+    main()
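
For reference, the same visualization can be driven from Python instead of the CLI, mirroring what `main()` does above. This is a minimal sketch based only on the signatures visible in this diff; the repo id and output path are example values, and `LeRobotDataset` may accept further arguments not shown here.

```
from pathlib import Path

from opentau.datasets.lerobot_dataset import LeRobotDataset
from opentau.scripts.visualize_dataset import visualize_dataset
from opentau.scripts.visualize_dataset_html import create_mock_train_config

# Mirrors main(): build the dataset, then write episode 0 to an .rrd file
# instead of spawning a local viewer ("lerobot/pusht" and the output path
# are example values).
dataset = LeRobotDataset(create_mock_train_config(), "lerobot/pusht", root=None, tolerance_s=1e-4)
rrd_path = visualize_dataset(
    dataset,
    episode_index=0,
    save=True,
    output_dir=Path("outputs/visualizations"),
)
print(rrd_path)  # e.g. outputs/visualizations/lerobot_pusht_episode_0.rrd
```
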