opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,221 @@ opentau/policies/value/siglip_gemma.py

```python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SigLIP + Gemma Model for Value Function Estimation.

This module defines the configuration and model classes for a value function estimator
that combines a SigLIP vision encoder and a Gemma language model.
"""

import torch
from einops import rearrange
from torch import nn
from transformers import (
    AutoConfig,
    Gemma3ForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
    SiglipVisionModel,
)
from transformers.models.auto import CONFIG_MAPPING


class SiglipGemmaValueConfig(PretrainedConfig):
    """Configuration class for SiglipGemmaValueModel.

    Args:
        siglip_config: Configuration for the SigLIP vision model.
        gemma_config: Configuration for the Gemma language model.
        num_value_bins: Number of bins for value discretization. Defaults to 201.
        **kwargs: Additional keyword arguments passed to `PretrainedConfig`.
    """

    model_type = "SiglipGemmaValueModel"
    sub_configs = {"siglip_config": AutoConfig, "gemma_config": AutoConfig}

    def __init__(
        self,
        siglip_config: dict | None = None,
        gemma_config: dict | None = None,
        num_value_bins: int = 201,
        **kwargs,
    ):
        self.num_value_bins = num_value_bins

        if siglip_config is None:
            # Default SigLIP config, similar to the PaliGemma vision config
            self.siglip_config = CONFIG_MAPPING["siglip_vision_model"](
                hidden_size=1152,
                intermediate_size=4304,
                model_type="siglip_vision_model",
                num_attention_heads=16,
                num_hidden_layers=27,
                num_image_tokens=256,
                patch_size=14,
                projector_hidden_act="gelu_fast",
                torch_dtype="float32",
                vision_use_head=False,
            )
        elif isinstance(siglip_config, dict):
            if "model_type" not in siglip_config:
                siglip_config["model_type"] = "siglip_vision_model"

            cfg_cls = CONFIG_MAPPING[siglip_config["model_type"]]
            self.siglip_config = cfg_cls(**siglip_config)
        else:
            self.siglip_config = siglip_config

        if gemma_config is None:
            # Default config for Gemma 3 270M
            # Based on typical scaling: smaller than the 1B model
            self.gemma_config = CONFIG_MAPPING["gemma"](
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=2,
                eos_token_id=1,
                head_dim=128,
                hidden_act="gelu_pytorch_tanh",
                hidden_activation="gelu_pytorch_tanh",
                hidden_size=640,
                initializer_range=0.02,
                intermediate_size=2048,
                max_position_embeddings=8192,
                model_type="gemma",
                num_attention_heads=8,
                num_hidden_layers=18,
                num_key_value_heads=1,
                pad_token_id=0,
                rms_norm_eps=1e-06,
                rope_theta=10000.0,
                torch_dtype="float32",
                transformers_version="4.48.1",
                use_cache=True,
                vocab_size=257152,
            )
        elif isinstance(gemma_config, dict):
            if "model_type" not in gemma_config:
                gemma_config["model_type"] = "gemma"

            cfg_cls = CONFIG_MAPPING[gemma_config["model_type"]]
            self.gemma_config = cfg_cls(**gemma_config)
        else:
            self.gemma_config = gemma_config

        super().__init__(**kwargs)

    def __post_init__(self):
        super().__post_init__()

        if self.attention_implementation not in ["eager", "fa2"]:
            raise ValueError(
                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager' or 'fa2'."
            )


class SiglipGemmaValueModel(PreTrainedModel):
    """SigLIP + Gemma Model for Value Function Estimation.

    This model combines a SigLIP vision encoder and a Gemma language model to estimate
    state values. It projects the final hidden state to a set of discretized value bins.

    Args:
        config: Configuration object of type `SiglipGemmaValueConfig`.
    """

    config_class = SiglipGemmaValueConfig

    def __init__(self, config: SiglipGemmaValueConfig):
        """Initializes the SiglipGemmaValueModel.

        Args:
            config: Configuration object of type `SiglipGemmaValueConfig`.
        """
        super().__init__(config=config)

        self.vision_encoder = SiglipVisionModel.from_pretrained("google/siglip2-so400m-patch14-224")

        # Initialize language model (Gemma 3 270M)
        self.gemma = Gemma3ForCausalLM.from_pretrained("google/gemma-3-270m")
        self.gemma = self.gemma.model  # we do not want the LM head

        # Value head: projects final hidden state to discretized value bins
        self.value_head = nn.Linear(640, config.num_value_bins)

    def embed_image(self, image: torch.Tensor) -> torch.Tensor:
        """Embeds images using the SigLIP vision encoder.

        Args:
            image: Tensor containing image pixel values.

        Returns:
            torch.Tensor: The embedded image features.
        """
        # Handle different transformers versions
        if hasattr(self.vision_encoder, "get_image_features"):
            return self.vision_encoder.get_image_features(image)
        else:
            outputs = self.vision_encoder(pixel_values=image)
            return outputs.last_hidden_state

    def embed_language_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embeds language tokens using the Gemma embedding layer.

        Args:
            tokens: Tensor containing language token IDs.

        Returns:
            torch.Tensor: The embedded language tokens.
        """
        return self.gemma.embed_tokens(tokens)

    def forward(
        self,
        inputs_embeds: torch.FloatTensor,
        attention_mask: torch.Tensor,
        position_ids: torch.LongTensor,
    ) -> torch.Tensor:
        """Forward pass that processes vision and language inputs and outputs a value.

        Args:
            inputs_embeds: Tensor of shape [batch_size, sequence_length, embedding_dim]
                containing the combined embeddings of images and text.
            attention_mask: Attention mask for the sequence.
            position_ids: Position IDs for RoPE.

        Returns:
            torch.Tensor: Logits for discretized values of shape [batch_size, num_value_bins].
        """
        attention_mask = rearrange(attention_mask, "b n1 n2 -> b 1 n1 n2")  # support multi-head attention
        # HACK: use full attention for sliding attention as well, since our context length
        # is almost the same size as the sliding window
        mask_mapping = {
            "full_attention": attention_mask,
            "sliding_attention": attention_mask,
        }
        outputs = self.gemma(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            attention_mask=mask_mapping,
        )
        hidden_states = outputs.last_hidden_state

        # Extract the last token's hidden state for value prediction
        # (the last token should be the final language token)
        final_hidden = hidden_states[:, -1]

        # Project to logits for discretized values
        logits = self.value_head(final_hidden)

        return logits
```
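For orientation, here is a minimal usage sketch of the value model (not part of the package; it assumes the SigLIP and Gemma checkpoints referenced in `__init__` can be fetched from the Hugging Face Hub, and the batch and sequence sizes are arbitrary). Note that `forward` takes pre-combined image/text embeddings plus a square per-sample attention mask:

```python
# Hypothetical usage sketch (not package code). Assumes network access to the
# HF Hub for the SigLIP and Gemma checkpoints loaded in __init__.
import torch

from opentau.policies.value.siglip_gemma import SiglipGemmaValueConfig, SiglipGemmaValueModel

config = SiglipGemmaValueConfig(num_value_bins=201)
model = SiglipGemmaValueModel(config).eval()

batch, seq_len, hidden = 1, 300, 640  # hidden must match the Gemma hidden_size (640)
inputs_embeds = torch.randn(batch, seq_len, hidden)
# Square mask: forward() rearranges it to [b, 1, seq, seq] for multi-head attention.
attention_mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool)
position_ids = torch.arange(seq_len).unsqueeze(0)

with torch.inference_mode():
    logits = model(inputs_embeds, attention_mask, position_ids)  # [1, 201] bin logits
    value_bin = logits.argmax(dim=-1)  # index of the most likely discretized value
```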
@@ -0,0 +1,89 @@ opentau/scripts/actions_mse_loss.py

```python
#!/usr/bin/env python

# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import asdict
from pprint import pformat

import numpy as np
import torch
from sklearn.metrics import r2_score
from torch.utils.data import DataLoader

from opentau.configs import parser
from opentau.configs.train import TrainPipelineConfig
from opentau.datasets.factory import make_dataset_mixture
from opentau.policies.factory import get_policy_class
from opentau.utils.random_utils import set_seed
from opentau.utils.utils import (
    attempt_torch_compile,
    auto_torch_device,
    init_logging,
)


@parser.wrap()
def inference_main(cfg: TrainPipelineConfig):
    logging.info(pformat(asdict(cfg)))
    # Build the LeRobot dataset and dataloader.
    datasets = make_dataset_mixture(cfg)

    # Load the trained or fine-tuned model. Set the batch size to 1 in the config.

    device = auto_torch_device()
    if cfg.seed is not None:
        set_seed(cfg.seed)

    logging.info("Creating policy")
    policy_class = get_policy_class(cfg.policy.type)
    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
    policy = policy.to(device=device, dtype=torch.bfloat16)
    policy.eval()
    policy_sample_actions = attempt_torch_compile(policy.sample_actions, device_hint=device)

    # Always reset the policy before an episode to clear the action cache.
    policy.reset()

    for dataset in datasets.datasets:
        robot_dof = dataset.meta.info["features"]["actions"]["shape"][0]
        assert cfg.max_action_dim >= robot_dof
        print(f"The batch size is {cfg.batch_size}")
        dataloader = DataLoader(dataset, batch_size=cfg.batch_size)

        pred = []
        truth = []
        with torch.inference_mode():
            for batch in dataloader:
                for key, value in batch.items():
                    if isinstance(value, torch.Tensor):
                        batch[key] = batch[key].to(device)
                action = policy_sample_actions(batch)
                predicted_action = action.to("cpu", torch.float32).numpy()
                pred.append(predicted_action[0, :, :robot_dof].squeeze(0))
                truth.append(batch["actions"][:, 0, :].squeeze(0)[:robot_dof].to(torch.float32).cpu().numpy())

        pred = np.stack(pred, axis=0)
        truth = np.stack(truth, axis=0)

        print(f"The mean squared error loss per dimension is {np.mean((pred - truth) ** 2, axis=0)}")

        print(f"The r2 score per dimension is {r2_score(truth, pred, multioutput='raw_values')}")
    logging.info("End of inference")


if __name__ == "__main__":
    init_logging()
    inference_main()
```
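To make the printed metrics concrete: rows are samples, columns are action dimensions, and both statistics reduce over rows. A toy, self-contained illustration (not part of the package):

```python
# Toy illustration of the per-dimension metrics the script prints.
import numpy as np
from sklearn.metrics import r2_score

truth = np.array([[0.0, 1.0], [0.2, 1.0], [0.4, 1.2]])  # [num_samples, robot_dof]
pred = np.array([[0.1, 0.9], [0.2, 1.1], [0.3, 1.2]])

mse_per_dim = np.mean((pred - truth) ** 2, axis=0)            # shape [robot_dof]
r2_per_dim = r2_score(truth, pred, multioutput="raw_values")  # shape [robot_dof]
print(mse_per_dim, r2_per_dim)
```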
@@ -0,0 +1,116 @@ opentau/scripts/bin_to_safetensors.py

```python
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy  # used to deepcopy non-tensor state_dict entries
import os
import sys

import torch
from safetensors.torch import save_file


def convert_bin_to_safetensors(input_path: str, output_path: str, map_location: str = "cpu"):
    """
    Converts a PyTorch .bin checkpoint (state_dict) to .safetensors format.

    This script attempts to handle cases where tensors might share memory by
    creating a deep copy of the state_dict before saving to .safetensors.
    This ensures that each tensor has its own memory, which is a requirement
    for `safetensors.torch.save_file` when shared tensors are detected.

    Args:
        input_path (str): Path to the input .bin file.
        output_path (str): Path where the .safetensors file will be saved.
        map_location (str): Device to map the tensors to when loading the .bin file.
            Defaults to 'cpu' to avoid GPU memory issues during conversion.
            Can be 'cuda', 'cuda:0', etc.
    """
    if not os.path.exists(input_path):
        print(f"Error: Input file not found at '{input_path}'", file=sys.stderr)
        sys.exit(1)

    if not output_path.lower().endswith(".safetensors"):
        print(
            f"Warning: Output path '{output_path}' does not end with '.safetensors'. Appending it.",
            file=sys.stderr,
        )
        output_path += ".safetensors"

    print(f"Attempting to load state_dict from '{input_path}'...")
    try:
        # Load the state_dict from the .bin file
        state_dict = torch.load(input_path, map_location=torch.device(map_location))  # nosec B614
        print("State_dict loaded successfully.")

        # Handle shared-memory tensors for safetensors compatibility.
        # Create a deep copy of the state_dict so every tensor gets a unique memory location;
        # this resolves the "Some tensors share memory" error from safetensors.
        # Note: this might increase the file size if many tensors were originally shared.
        unique_state_dict = {}
        for key, value in state_dict.items():
            if isinstance(value, torch.Tensor):
                unique_state_dict[key] = value.clone().detach()
            else:
                unique_state_dict[key] = copy.deepcopy(value)  # handle non-tensor items (e.g., lists, dicts)

        # Save the state_dict to .safetensors format
        print(f"Saving state_dict to '{output_path}' in .safetensors format...")
        save_file(unique_state_dict, output_path)
        print(f"Conversion successful! Output saved to '{output_path}'")

    except Exception as e:
        print(f"An error occurred during conversion: {e}", file=sys.stderr)
        # Provide more specific guidance for the known safetensors shared-memory error
        if "Some tensors share memory" in str(e):
            print(
                "\nThis error typically occurs when the PyTorch state_dict contains tensors that share the same underlying memory (e.g., `lm_head.weight` and `embed_tokens.weight`)."
            )
            print("The script attempted to resolve this by deep-copying the state_dict before saving.")
            print(
                "If the issue persists, ensure your `safetensors` library is up-to-date (`pip install --upgrade safetensors`)."
            )
            print(
                "For models with complex shared weight patterns, manually loading the model architecture and using `safetensors.torch.save_model(model, output_path)` might be necessary, as it handles shared weights more robustly."
            )
        sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a PyTorch .bin checkpoint (state_dict) to .safetensors format."
    )
    parser.add_argument("input_file", type=str, help="Path to the input .bin model weights file.")
    parser.add_argument(
        "--output_file",
        type=str,
        help="Path to save the output .safetensors file. If not provided, "
        "it will be inferred from the input filename (e.g., model.bin -> model.safetensors).",
    )
    parser.add_argument(
        "--map_location",
        type=str,
        default="cpu",
        help="Device to map the tensors to when loading the .bin file. "
        "Defaults to 'cpu'. Can be 'cuda', 'cuda:0', etc.",
    )

    args = parser.parse_args()

    # If output_file is not provided, infer it
    if args.output_file is None:
        base_name = os.path.splitext(args.input_file)[0]
        args.output_file = base_name + ".safetensors"

    convert_bin_to_safetensors(args.input_file, args.output_file, args.map_location)
```
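A quick round-trip check of the converter (a sketch, not package code; the file names are arbitrary):

```python
# Hypothetical round-trip test for the converter above.
import torch
from safetensors.torch import load_file

from opentau.scripts.bin_to_safetensors import convert_bin_to_safetensors

state_dict = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.zeros(4)}
torch.save(state_dict, "toy.bin")

convert_bin_to_safetensors("toy.bin", "toy.safetensors")

restored = load_file("toy.safetensors")
assert torch.equal(restored["linear.weight"], state_dict["linear.weight"])
```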
@@ -0,0 +1,111 @@ opentau/scripts/compute_max_token_length.py

```python
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
from collections import Counter
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
from multiprocessing import Pool, cpu_count
from pathlib import Path

from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer

from opentau.configs import parser
from opentau.configs.train import TrainPipelineConfig
from opentau.datasets.factory import make_dataset_mixture
from opentau.datasets.lerobot_dataset import BaseDataset
from opentau.policies.factory import get_policy_class
from opentau.policies.pi0.modeling_pi0 import PI0Policy
from opentau.policies.pi05.modeling_pi05 import PI05Policy


@dataclass
class Args:
    target_cfg: str  # path to the training configuration file
    keys: tuple[str, ...] = (
        "response",
        "prompt",
    )  # keys to compute the max token length for, e.g. ["response", "prompt"]
    num_workers: int | None = None
    chunk_size: int = 1000
    output_path: str | None = None


def get_tokenizer(cfg: TrainPipelineConfig) -> callable:
    r"""Returns a tokenizer based on the policy type in the configuration."""
    policy_class = get_policy_class(cfg.policy.type)

    # TODO: Add `elif` for other policy types if needed
    if issubclass(policy_class, PI0Policy) or issubclass(policy_class, PI05Policy):
        return AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")

    raise ValueError(f"Unsupported policy type: {cfg.policy.type}")


def chunked(dataset: BaseDataset, key: str, chunk_size: int):
    n = len(dataset)
    for start in range(0, n, chunk_size):
        end = min(n, start + chunk_size)
        yield [dataset[i][key] for i in range(start, end)]


def worker_fn(chunk, tokenizer: PreTrainedTokenizer):
    return Counter(len(tokenizer(s)["input_ids"]) for s in chunk)


def to_percentile(counter: Counter) -> dict[int, float]:
    r"""Convert a counter to a dict mapping each token length to the fraction of samples at or below it."""
    total = counter.total()
    sorted_keys = sorted(counter.keys())
    values = accumulate(counter[k] for k in sorted_keys)
    return {k: v / total for k, v in reversed(list(zip(sorted_keys, values, strict=False)))}


@parser.wrap()
def main(args: Args):
    cfg = TrainPipelineConfig.from_pretrained(args.target_cfg)
    datasets = make_dataset_mixture(cfg).datasets
    tokenizer = get_tokenizer(cfg)
    worker = partial(worker_fn, tokenizer=tokenizer)

    output = {}
    for key in args.keys:
        counter = Counter()
        for ds in datasets:
            tasks = tqdm(
                chunked(ds, key, args.chunk_size),
                total=math.ceil(len(ds) / args.chunk_size),
                desc=f"Processing {key} in {ds._get_feature_mapping_key()}",
            )
            # TODO: multiprocessing doesn't seem to speed things up; debug why.
            with Pool(args.num_workers or cpu_count()) as pool:
                parts = pool.imap_unordered(worker, tasks)
                counter = sum(parts, start=counter)

        output[key] = to_percentile(counter)

    if args.output_path:
        path = Path(args.output_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            json.dump(output, f, indent=2)
    else:
        print(json.dumps(output, indent=2))


if __name__ == "__main__":
    main()
```
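`to_percentile` produces a cumulative distribution over token lengths, from which a maximum token length can be read off at a chosen coverage. A worked example (not package code):

```python
# Worked example of to_percentile: for each observed token length, the value is
# the fraction of samples whose length is at or below it.
from collections import Counter
from itertools import accumulate

counter = Counter({5: 2, 7: 1, 9: 1})  # token length -> count, over 4 samples
total = counter.total()
sorted_keys = sorted(counter)  # [5, 7, 9]
cumulative = list(accumulate(counter[k] for k in sorted_keys))  # [2, 3, 4]
percentiles = {k: v / total for k, v in zip(sorted_keys, cumulative)}
# {5: 0.5, 7: 0.75, 9: 1.0} -> e.g. a max length of 7 covers 75% of samples

# Choose the smallest length that covers at least 99% of samples:
max_len = min(k for k, p in percentiles.items() if p >= 0.99)  # 9
```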
@@ -0,0 +1,90 @@ opentau/scripts/display_sys_info.py

```python
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Use this script to get a quick summary of your system config.

It should be able to run without any of OpenTau's dependencies or OpenTau itself installed.
"""

import platform

HAS_HF_HUB = True
HAS_HF_DATASETS = True
HAS_NP = True
HAS_TORCH = True
HAS_OPENTAU = True

try:
    import huggingface_hub
except ImportError:
    HAS_HF_HUB = False

try:
    import datasets
except ImportError:
    HAS_HF_DATASETS = False

try:
    import numpy as np
except ImportError:
    HAS_NP = False

try:
    import torch
except ImportError:
    HAS_TORCH = False

try:
    import opentau
except ImportError:
    HAS_OPENTAU = False


opentau_version = opentau.__version__ if HAS_OPENTAU else "N/A"
hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
np_version = np.__version__ if HAS_NP else "N/A"

torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"


def display_sys_info() -> dict:
    """Run this to get basic system info to help with tracking issues and bugs."""
    info = {
        "`opentau` version": opentau_version,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "Huggingface_hub version": hf_hub_version,
        "Dataset version": hf_datasets_version,
        "Numpy version": np_version,
        "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
        "Cuda version": cuda_version,
        "Using GPU in script?": "<fill in>",
        # "Using distributed or parallel set-up in script?": "<fill in>",
    }
    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
    print(format_dict(info))
    return info


def format_dict(d: dict) -> str:
    return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"


if __name__ == "__main__":
    display_sys_info()
```
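For reference, `format_dict` renders the report as a bulleted key/value list; an illustration with made-up values (not package code):

```python
# Illustration of the report format produced by format_dict; values are invented.
info = {"Platform": "Linux-6.8.0-x86_64", "Python version": "3.11.9", "PyTorch version (GPU?)": "2.4.0 (True)"}
print("\n".join(f"- {prop}: {val}" for prop, val in info.items()))
# - Platform: Linux-6.8.0-x86_64
# - Python version: 3.11.9
# - PyTorch version (GPU?): 2.4.0 (True)
```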
@@ -0,0 +1,54 @@ opentau/scripts/download_libero_benchmarks.py

```python
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from dataclasses import dataclass
from pathlib import Path

from libero.libero import get_libero_path
from libero.libero.utils.download_utils import check_libero_dataset, libero_dataset_download

from opentau.configs import parser


@dataclass
class Args:
    suite: str | None = None
    download_dir: str = get_libero_path("datasets")

    def __post_init__(self):
        if self.suite not in [None, "object", "spatial", "goal", "10", "90"]:
            raise ValueError(
                f"Invalid suite: {self.suite}. Available suites are: 'object', 'spatial', 'goal', '10', or '90'."
            )


@parser.wrap()
def main(args: Args):
    # Ask users to specify the download directory of datasets
    download_dir = Path(args.download_dir)
    download_dir.mkdir(parents=True, exist_ok=True)
    download_dir = str(download_dir.resolve())
    print("Datasets will be downloaded to:", download_dir)

    datasets = "all" if args.suite is None else f"libero_{args.suite}"
    print("Datasets to download:", datasets)

    libero_dataset_download(datasets=datasets, download_dir=download_dir, use_huggingface=True)
    time.sleep(1)
    check_libero_dataset(download_dir=args.download_dir)


if __name__ == "__main__":
    main()
```
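The script is driven by `opentau.configs.parser`, but the same download can be performed directly with the LIBERO helpers it wraps (a sketch, assuming `libero` is installed; the directory name is arbitrary):

```python
# Hypothetical direct use of the LIBERO helpers wrapped by the script above.
from libero.libero.utils.download_utils import check_libero_dataset, libero_dataset_download

download_dir = "./libero_datasets"
libero_dataset_download(datasets="libero_spatial", download_dir=download_dir, use_huggingface=True)
check_libero_dataset(download_dir=download_dir)
```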