opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from opentau.configs import parser
|
|
21
|
+
from opentau.configs.train import TrainPipelineConfig
|
|
22
|
+
from opentau.policies.factory import get_policy_class
|
|
23
|
+
from opentau.policies.pi0.modeling_pi0 import PI0Policy
|
|
24
|
+
from opentau.policies.pi05.modeling_pi05 import PI05Policy
|
|
25
|
+
from opentau.utils.monkey_patch import (
|
|
26
|
+
torch_cumsum_patch,
|
|
27
|
+
torch_full_patch,
|
|
28
|
+
torch_pow_patch,
|
|
29
|
+
)
|
|
30
|
+
from opentau.utils.utils import auto_torch_device
|
|
31
|
+
|
|
32
|
+
# Some patches are necessary only for dynamo export, which has current upstream bugs.
# Nonetheless, we apply them here to ensure future compatibility.
patches = [
    torch_cumsum_patch,  # This is always necessary to load the ONNX artifact without error.
    torch_full_patch,
    torch_pow_patch,
]

# Names used both as keys into the VLM token cache and as ONNX graph input names
# (see `torch.onnx.export(..., input_names=...)` in `main`).
KEY_STATES = "key_states"
VALUE_STATES = "value_states"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class InferenceWrapper(torch.nn.Module):
    r"""Wrap the robot action decoder so ONNX export sees flat tensor inputs.

    ``torch.onnx.export`` requires each input tensor to be an individual
    positional argument to ``forward``.  The VLM prefix metadata captured at
    construction time is re-assembled, together with the per-call key/value
    states, into the structure that ``sample_actions`` expects via its
    ``vlm_token_cache_override`` argument.
    """

    def __init__(
        self,
        decoder: PI0Policy | PI05Policy,
        *,
        prefix_pad_masks: torch.Tensor,
        prefix_offsets: torch.Tensor,
        num_cross_att_tokens: int,
        layer_idx: int,
    ):
        super().__init__()
        self.decoder = decoder
        # Captured from a dummy observation; baked into the exported graph.
        self.prefix_pad_masks = prefix_pad_masks
        self.prefix_offsets = prefix_offsets
        self.num_cross_att_tokens = num_cross_att_tokens
        self.layer_idx = layer_idx

    def forward(self, key_states, value_states, state):
        # Rebuild the per-layer KV-cache mapping the decoder expects.
        layer_cache = {
            self.layer_idx: {
                KEY_STATES: key_states,
                VALUE_STATES: value_states,
            }
        }
        vlm_tokens = (
            layer_cache,
            self.prefix_pad_masks,
            self.prefix_offsets,
            self.num_cross_att_tokens,
        )
        observation = {"state": state}
        return self.decoder.sample_actions(
            observation,
            vlm_token_cache_override=vlm_tokens,
        )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Get the VLM cache for the dummy observation. This guarantees consistency with post-loading usage.
def get_vlm_cache(cfg: TrainPipelineConfig, device: torch.device, dtype: torch.dtype):
    """Run the cloud-side VLM once on a dummy observation and return its token cache.

    Returns a 6-tuple:
        key_states and value_states (cast to ``dtype``), prefix_pad_masks,
        prefix_offsets, num_cross_att_tokens, and the single layer index under
        which the cache entry was stored.
    """
    logging.info("Getting VLM cache...")
    policy_class = get_policy_class(cfg.policy.type)
    cloud_vlm = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
    cloud_vlm.set_execution_target("cloud")
    # The VLM itself runs in bfloat16; only the returned cached states are cast to `dtype`.
    cloud_vlm.to(device=device, dtype=torch.bfloat16)
    cloud_vlm.eval()

    # Dummy zero-filled camera frames, one per configured camera.
    vlm_camera_observation = {
        f"camera{i}": torch.zeros((1, 3, *cfg.resolution), dtype=torch.bfloat16, device=device)
        for i in range(cfg.num_cams)
    }
    vlm_observation = {
        **vlm_camera_observation,
        "prompt": ["Pick up yellow lego block and put it in the bin"],
        "state": torch.zeros((1, cfg.max_state_dim), dtype=torch.bfloat16, device=device),
        "img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device),
    }
    cache, prefix_pad_masks, prefix_offsets, num_cross_att_tokens = cloud_vlm.get_vlm_tokens(vlm_observation)
    # The dummy observation is expected to populate exactly one layer entry;
    # `idx` is that layer's key in the cache mapping.
    assert len(cache) == 1, f"Expected only one cache entry for the dummy observation. Got {len(cache)}."
    idx = list(cache)[0]
    return (
        cache[idx][KEY_STATES].to(dtype=dtype),
        cache[idx][VALUE_STATES].to(dtype=dtype),
        prefix_pad_masks,
        prefix_offsets,
        num_cross_att_tokens,
        idx,
    )
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@parser.wrap()
def main(cfg: TrainPipelineConfig):
    """Export the robot-side action decoder to ONNX.

    The cloud VLM is run once on a dummy observation to capture its token
    cache (see `get_vlm_cache`); the decoder is then wrapped so the cached
    tensors and the robot state become flat ONNX graph inputs.
    """
    device = auto_torch_device()
    # The exported graph runs in float32; the VLM cache is cast accordingly.
    dtype = torch.float32

    # arguments for the dummy observation
    (
        key_states,
        value_states,
        prefix_pad_masks,
        prefix_offsets,
        num_cross_att_tokens,
        layer_idx,
    ) = get_vlm_cache(cfg, device, dtype)
    state = torch.zeros((1, cfg.max_state_dim), device=device, dtype=dtype)
    args = (key_states, value_states, state)
    logging.info("Generated example args")

    policy_class = get_policy_class(cfg.policy.type)
    robot_action_decoder = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
    robot_action_decoder.set_execution_target("robot")
    robot_action_decoder.to(device)
    robot_action_decoder.to(dtype=dtype)
    robot_action_decoder.eval()
    inference_wrapper = InferenceWrapper(
        robot_action_decoder,
        prefix_pad_masks=prefix_pad_masks,
        prefix_offsets=prefix_offsets,
        num_cross_att_tokens=num_cross_att_tokens,
        layer_idx=layer_idx,
    )
    logging.info("Loaded policy")

    logging.info("Applying monkey patches...")
    for patch in patches:
        patch()

    logging.info("Exporting model to ONNX...")
    with torch.inference_mode():
        path = Path(cfg.policy.pretrained_path) / "robot_action_decoder.onnx"
        path = path.resolve()
        path.parent.mkdir(parents=True, exist_ok=True)  # Should be a no-op
        print("Exporting model to ONNX at path:", path)
        print("Current directory:", Path.cwd())
        print("Trying to write to:", path)
        # Fail fast with a clear error if the target path is not writable.
        # NOTE(review): this open() truncates any existing artifact even if the
        # subsequent export fails — confirm that is acceptable.
        with open(path, "wb"):
            print("Write permissions check passed for:", path)
        print("Running torch.onnx.export...")
        torch.onnx.export(
            inference_wrapper.eval(),
            args,
            path,
            input_names=[KEY_STATES, VALUE_STATES, "state"],
            output_names=["action_chunk"],
            opset_version=18,
            do_constant_folding=False,  # constant folding causes weird errors (getting dim -1 from a 0-dim scalar) after forward pass succeeds
        )
        logging.info(f"Successfully exported model to '{path}'.")
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
if __name__ == "__main__":
|
|
180
|
+
main()
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from opentau.configs import parser
|
|
20
|
+
from opentau.utils.fake_tensor import FakeTensorContext
|
|
21
|
+
from opentau.utils.utils import auto_torch_device
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class Args:
    # Torch device string (e.g. "cuda"); None -> auto-detect via auto_torch_device().
    device: str | None = None
    # DataLoader worker count; multiple workers is why the dataloader cannot
    # itself run inside FakeTensorMode (see the comment in `main`).
    num_workers: int = 4
    batch_size: int = 2
    # Input / output feature sizes for the toy regression data.
    dim_in: int = 3
    dim_out: int = 5
    # Deliberately enormous hidden width: under FakeTensorMode no real
    # parameter memory is allocated, which is the point of the demo.
    large_hidden_dim: int = 10**9
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@parser.wrap()
def main(args: Args):
    """Demonstrate FakeTensorMode: 'train' a model far too large for real
    memory. Fake tensors carry only metadata (shape/dtype), no storage, so
    the loss values printed are symbolic, not real numbers.
    """
    device = args.device or auto_torch_device()
    # Real (non-fake) toy dataset; see the note below about why it stays real.
    data = [(torch.rand(args.dim_in), torch.rand(args.dim_out)) for _ in range(10)]
    dataloader = torch.utils.data.DataLoader(
        data,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
    )

    # ---------------------------------------------------------------
    # Everything inside this context will use FakeTensorMode
    with FakeTensorContext():
        # Create a model in FakeTensorContext shouldn't cost real memory for model parameters.
        model = torch.nn.Sequential(
            torch.nn.Linear(args.dim_in, args.large_hidden_dim),
            torch.nn.Linear(args.large_hidden_dim, args.large_hidden_dim),
            torch.nn.Linear(args.large_hidden_dim, args.dim_out),
        ).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    # End of FakeTensorContext
    # ---------------------------------------------------------------

    print("Model parameters: ")
    for name, param in model.named_parameters():
        print(f"{name}: {param})")
    print()

    losses = []
    for i, (x, y) in enumerate(dataloader):
        # ---------------------------------------------------------------
        # Everything inside this context will use FakeTensorMode
        # Ideally, we want to iterate the dataloader in FakeTensorMode as well.
        # However, it does not work with multiple workers due to some serialization issue.
        with FakeTensorContext():
            x = x.to(device=device)
            y = y.to(device=device)
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()

            optimizer.step()
            # loss.item() on a fake tensor yields a symbolic/dummy value.
            losses.append(loss.item())
            print(
                f"Step {i}: symbolic loss = {loss.item()}, dummy numpy array = {loss.detach().cpu().numpy()}"
            )
        # End of FakeTensorContext
        # ---------------------------------------------------------------

    print("\nSymbolic mean loss:", sum(losses) / len(losses))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
if __name__ == "__main__":
|
|
87
|
+
main()
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
r"python src/opentau/scripts/get_advantage_and_percentiles.py \
|
|
2
|
+
--config_path=outputs/train/2025-11-29/00-38-59_value/checkpoints/00520000 \
|
|
3
|
+
--batch_size=20 \
|
|
4
|
+
--dataloader_batch_size=20 \
|
|
5
|
+
--dataset_mixture=examples/advantage_config.json"
|
|
6
|
+
|
|
7
|
+
#!/usr/bin/env python
|
|
8
|
+
|
|
9
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
10
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
11
|
+
#
|
|
12
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
13
|
+
# you may not use this file except in compliance with the License.
|
|
14
|
+
# You may obtain a copy of the License at
|
|
15
|
+
#
|
|
16
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
17
|
+
#
|
|
18
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
19
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
20
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
21
|
+
# See the License for the specific language governing permissions and
|
|
22
|
+
# limitations under the License.
|
|
23
|
+
import json
|
|
24
|
+
import logging
|
|
25
|
+
import sys
|
|
26
|
+
from collections import defaultdict
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
|
|
29
|
+
import draccus
|
|
30
|
+
import numpy as np
|
|
31
|
+
import torch
|
|
32
|
+
from torch.utils.data import DataLoader
|
|
33
|
+
|
|
34
|
+
from opentau.configs import parser
|
|
35
|
+
from opentau.configs.default import DatasetMixtureConfig
|
|
36
|
+
from opentau.configs.train import TrainPipelineConfig
|
|
37
|
+
from opentau.datasets.factory import make_dataset
|
|
38
|
+
from opentau.datasets.lerobot_dataset import LeRobotDataset
|
|
39
|
+
from opentau.datasets.utils import ADVANTAGES_PATH
|
|
40
|
+
from opentau.policies.factory import get_policy_class
|
|
41
|
+
from opentau.policies.value.reward import calculate_n_step_return
|
|
42
|
+
from opentau.utils.random_utils import set_seed
|
|
43
|
+
from opentau.utils.utils import (
|
|
44
|
+
auto_torch_device,
|
|
45
|
+
init_logging,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def ensure_primitive(maybe_tensor):
    """Collapse a single-element array/tensor to a plain Python scalar.

    Plain Python values (int/float/bool) are returned unchanged.  Returning a
    true primitive matters downstream: the results are used inside dict keys
    (which must hash by value, not by tensor identity) and are serialized with
    ``json.dump``, which cannot handle tensors.

    Args:
        maybe_tensor: a Python scalar, a single-element ``np.ndarray``, or a
            single-element ``torch.Tensor``.

    Returns:
        A plain Python scalar.

    Raises:
        AssertionError: if a tensor/array holds more than one element.
    """
    if isinstance(maybe_tensor, np.ndarray):
        return ensure_primitive(torch.from_numpy(maybe_tensor))
    if isinstance(maybe_tensor, torch.Tensor):
        assert maybe_tensor.numel() == 1, f"Tensor must be a single value, got shape={maybe_tensor.numel()}"
        # .item() yields a hashable, JSON-serializable Python scalar.
        # (The previous version returned the tensor itself, so dict keys built
        # from these values compared by identity and json.dump failed.)
        return maybe_tensor.item()
    # Already a primitive: pass through instead of silently returning None,
    # as the original implicit fall-through did.
    return maybe_tensor
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Shared fallback for `values.get(...)` in main(): a defaultdict(int) so
# missing "v0"/"reward" lookups yield 0.
# NOTE(review): indexing into this default inserts keys into the shared
# object — harmless here since only int 0 is ever produced, but it is
# mutated module-level state.
_default0 = defaultdict(int)

# Store dataset_mixture_path before filtering (needed for parsing inside main)
# Handle both --dataset_mixture_path=<path> and --dataset_mixture=<path> (without nested fields)
_dataset_mixture_path_value = None
for arg in sys.argv:
    if arg.startswith("--dataset_mixture_path="):
        _dataset_mixture_path_value = arg.split("=", 1)[1]
        break
    elif arg.startswith("--dataset_mixture=") and "." not in arg.split("=", 1)[0]:
        # --dataset_mixture=<path> without nested fields (e.g., not --dataset_mixture.datasets.0.repo_id=...)
        _dataset_mixture_path_value = arg.split("=", 1)[1]
        break

# Create a wrapper that filters dataset_mixture path arguments before draccus parsing
_original_wrap = parser.wrap()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _filter_dataset_mixture_path(fn):
    """Decorator: strip ``--dataset_mixture_path=<path>`` and bare
    ``--dataset_mixture=<path>`` arguments from ``sys.argv`` before the
    draccus-based parser runs, restoring ``sys.argv`` afterwards.
    """
    wrapped_fn = _original_wrap(fn)

    def _is_mixture_path_arg(arg):
        # --dataset_mixture_path=<path> is always filtered.
        if arg.startswith("--dataset_mixture_path="):
            return True
        # Bare --dataset_mixture=<path>, but not nested fields such as
        # --dataset_mixture.datasets.0.repo_id=...
        return arg.startswith("--dataset_mixture=") and "." not in arg.split("=", 1)[0]

    def filtered_wrapper(*args, **kwargs):
        # A config object was passed in directly: nothing to filter.
        if args:
            return wrapped_fn(*args, **kwargs)

        # Temporarily hide the dataset_mixture path args from draccus.
        original_argv = sys.argv.copy()
        try:
            sys.argv = [arg for arg in sys.argv if not _is_mixture_path_arg(arg)]
            return wrapped_fn(*args, **kwargs)
        finally:
            sys.argv = original_argv

    return filtered_wrapper
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@_filter_dataset_mixture_path
def main(cfg: TrainPipelineConfig):
    """Compute per-step advantages for each dataset in a mixture and print
    advantage percentiles (used to pick an epsilon threshold).

    Two passes per dataset: the first caches value predictions and n-step
    returns per (episode, step); the second combines them into advantages
    reward + V(s_{t+N}) - V(s_t), writes a per-dataset JSON file, and
    accumulates advantages across all datasets for the percentile report.
    """
    dataset_mixture_path = _dataset_mixture_path_value

    if dataset_mixture_path:
        logging.info(f"Loading dataset config from separate file: {dataset_mixture_path}")
        mixture_cfg = draccus.parse(
            config_class=DatasetMixtureConfig, config_path=dataset_mixture_path, args=[]
        )
    else:
        logging.info("Using the dataset mixture config from the TrainPipelineConfig")
        mixture_cfg = cfg.dataset_mixture

    device = auto_torch_device()
    # torch.autograd.set_detect_anomaly(True)

    # TODO(shuheng): Do we need the random seed here?
    if cfg.seed is not None:
        set_seed(cfg.seed)

    logging.info("Creating policy")
    policy_class = get_policy_class(cfg.policy.type)
    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
    policy.to(device=device, dtype=torch.bfloat16)
    policy.eval()

    # Advantages across all datasets
    advantages = []

    for dataset_idx, dataset_cfg in enumerate(mixture_cfg.datasets):
        logging.info(f"Creating dataset {dataset_idx}")
        dataset = make_dataset(dataset_cfg, cfg, return_advantage_input=True)
        assert isinstance(dataset, LeRobotDataset), (
            f"Expected instance of LeRobotDataset, got {type(dataset)}"
        )
        # shuffle=False so both passes visit samples in the same order.
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=cfg.num_workers,
            pin_memory=torch.cuda.is_available(),
            prefetch_factor=cfg.prefetch_factor,
        )

        values = {}
        ds_advantage = {}  # per-dataset advantages
        with torch.inference_mode():
            # First pass to get the values
            for batch in dataloader:
                for key, value in batch.items():
                    if isinstance(value, torch.Tensor):
                        batch[key] = value.to(device)

                for success, episode_index, episode_end_idx, current_idx, v0 in zip(
                    batch["success"],
                    batch["episode_index"],
                    batch["episode_end_idx"],
                    batch["current_idx"],
                    policy.predict_value(batch),
                    strict=True,
                ):
                    # NOTE(review): (episode_index, current_idx) is used as a
                    # dict key below — ensure_primitive must return hashable
                    # Python scalars (not tensors) for the second-pass
                    # values.get(...) lookups to hit; verify.
                    success, episode_index, episode_end_idx, current_idx, v0 = map(
                        ensure_primitive, (success, episode_index, episode_end_idx, current_idx, v0)
                    )
                    reward = calculate_n_step_return(
                        success=success,
                        n_steps_look_ahead=cfg.policy.reward_config.N_steps_look_ahead,
                        episode_end_idx=episode_end_idx,
                        max_episode_length=cfg.policy.reward_config.reward_normalizer,
                        current_idx=current_idx,
                        c_neg=cfg.policy.reward_config.C_neg,
                    )

                    values[(episode_index, current_idx)] = {"v0": v0, "reward": reward}

            # Second pass to compute the advantages
            for batch in dataloader:
                for episode_index, current_idx, timestamp in zip(
                    batch["episode_index"],
                    batch["current_idx"],
                    batch["timestamp"],
                    strict=True,
                ):
                    episode_index, current_idx, timestamp = map(
                        ensure_primitive, (episode_index, current_idx, timestamp)
                    )
                    # check if the value for the next n_steps_look_ahead steps is available, else set it to 0
                    look_ahead_idx = current_idx + cfg.policy.reward_config.N_steps_look_ahead
                    vn = values.get((episode_index, look_ahead_idx), _default0)["v0"]
                    reward = values.get((episode_index, current_idx), _default0)["reward"]
                    v0 = values.get((episode_index, current_idx), _default0)["v0"]
                    advantage = ensure_primitive(reward + vn - v0)
                    advantages.append(advantage)
                    ds_advantage[(episode_index, timestamp)] = advantage

        # Convert tuple keys to strings for JSON serialization
        advantage_data_json = {f"{ep_idx},{ts}": val for (ep_idx, ts), val in ds_advantage.items()}

        # TODO(shuheng) avoid overwriting existing advantage files.
        with open(Path(dataset.root) / ADVANTAGES_PATH, "w") as f:
            json.dump(advantage_data_json, f, indent=4)

    # Calculate percentiles of advantages: 0th, 5th, 10th, ..., 100th
    percentiles = list(range(0, 101, 5))  # [0, 5, 10, 15, ..., 100]
    advantage_percentiles = np.percentile(np.array(advantages), percentiles)

    print("Advantage percentiles for deciding epsilon threshold:")
    for p, val in zip(percentiles, advantage_percentiles, strict=False):
        print(f"  {p:03d}th percentile: {val:.6f}")
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
if __name__ == "__main__":
|
|
219
|
+
init_logging()
|
|
220
|
+
main()
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
#!/usr/bin/env python
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from dataclasses import asdict
|
|
19
|
+
from pprint import pformat
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
from dotenv import load_dotenv
|
|
23
|
+
|
|
24
|
+
from opentau.configs import parser
|
|
25
|
+
from opentau.configs.train import TrainPipelineConfig
|
|
26
|
+
from opentau.planner import HighLevelPlanner, Memory
|
|
27
|
+
from opentau.policies.factory import get_policy_class
|
|
28
|
+
from opentau.utils.random_utils import set_seed
|
|
29
|
+
from opentau.utils.utils import (
|
|
30
|
+
init_logging,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
load_dotenv()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@parser.wrap()
def inference_main(cfg: TrainPipelineConfig):
    """
    Reflects the whole pipeline from passing tasks to high level planner to generating actions from low level planner

    Args:
        cfg: configuration file. For example look at examples/dev_config.json
    """

    logging.info(pformat(asdict(cfg)))

    # Check device is available
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if cfg.seed is not None:
        set_seed(cfg.seed)

    logging.info("Creating policy")
    policy_class = get_policy_class(cfg.policy.type)
    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
    policy.to(device)
    policy.to(dtype=torch.bfloat16)
    policy.eval()

    hlp = HighLevelPlanner()
    # NOTE(review): len=2 presumably bounds the conversation history kept by
    # Memory — confirm against opentau.planner.Memory.
    mem = Memory(len=2)

    # compile the model if possible
    if hasattr(torch, "compile"):
        logging.info("Attempting to compile the policy with torch.compile()...")
        try:
            # Other options: "default", "max-autotune" (longer compile time)
            policy = torch.compile(policy)
            logging.info("Policy compiled successfully.")
        except Exception as e:
            logging.warning(f"torch.compile failed with error: {e}. Proceeding without compilation.")
    else:
        logging.warning(
            "torch.compile is not available. Requires PyTorch 2.0+. Proceeding without compilation."
        )

    # Always reset policy before episode to clear out action cache.
    policy.reset()

    # Five planner->policy rounds; the loop index itself is unused (the inner
    # comprehension's `i` is a separate, comprehension-local variable).
    for i in range(5):
        # create dummy observation for pi05
        camera_observations = {
            f"camera{i}": torch.zeros((1, 3, *cfg.resolution), dtype=torch.bfloat16, device=device)
            for i in range(cfg.num_cams)
        }

        task = "Pick up yellow lego block and put it in the bin"

        # NOTE(review): assumes the planner reply has the sub-task quoted on
        # its second line; raises IndexError if the format differs — confirm.
        sub_task = hlp.inference(camera_observations, "", task, "gpt4o", mem).split("\n")[1].split('"')[1]

        if mem:
            mem.add_conversation("assistant", [{"type": "text", "text": sub_task}])

        logging.info(f"{sub_task}")
        observation = {
            **camera_observations,
            "state": torch.zeros((1, cfg.max_state_dim), dtype=torch.bfloat16, device=device),
            "prompt": [sub_task],
            "img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device),
            "action_is_pad": torch.zeros((1, cfg.action_chunk), dtype=torch.bool, device=device),
        }

        with torch.inference_mode():
            for _ in range(1000):
                action = policy.select_action(observation)
                action = action.to("cpu").numpy()
                print(f"Output dummy action: {action}")

    logging.info("End of inference")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
if __name__ == "__main__":
|
|
113
|
+
init_logging()
|
|
114
|
+
inference_main()
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
#!/usr/bin/env python
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from dataclasses import asdict
|
|
19
|
+
from pprint import pformat
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
|
|
23
|
+
from opentau.configs import parser
|
|
24
|
+
from opentau.configs.train import TrainPipelineConfig
|
|
25
|
+
from opentau.policies.factory import get_policy_class
|
|
26
|
+
from opentau.utils.random_utils import set_seed
|
|
27
|
+
from opentau.utils.utils import (
|
|
28
|
+
attempt_torch_compile,
|
|
29
|
+
auto_torch_device,
|
|
30
|
+
create_dummy_observation,
|
|
31
|
+
init_logging,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@parser.wrap()
def inference_main(cfg: TrainPipelineConfig):
    """Smoke-test a pretrained policy: load and (optionally) compile it, then
    repeatedly run `select_action` on a dummy observation, printing the output
    shape each time.
    """
    logging.info(pformat(asdict(cfg)))

    # Check device is available
    device = auto_torch_device()

    if cfg.seed is not None:
        set_seed(cfg.seed)

    logging.info("Creating policy")
    policy_cls = get_policy_class(cfg.policy.type)
    policy = policy_cls.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
    policy.to(device=device, dtype=torch.bfloat16)
    policy.eval()
    policy = attempt_torch_compile(policy, device_hint=device)

    # Always reset policy before episode to clear out action cache.
    policy.reset()

    observation = create_dummy_observation(cfg, device, dtype=torch.bfloat16)
    print(observation.keys())

    with torch.inference_mode():
        for _ in range(1000):
            action = policy.select_action(observation).to("cpu", torch.float32).numpy()
            print(f"Output shape: {action.shape}")

    logging.info("End of inference")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
if __name__ == "__main__":
|
|
69
|
+
init_logging()
|
|
70
|
+
inference_main()
|