opentau-0.1.0-py3-none-any.whl

Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/policies/value/siglip_gemma.py
@@ -0,0 +1,221 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """SigLIP + Gemma Model for Value Function Estimation.
+
+ This module defines the configuration and model classes for a value function estimator
+ that combines a SigLIP vision encoder and a Gemma language model.
+ """
+
+ import torch
+ from einops import rearrange
+ from torch import nn
+ from transformers import (
+     AutoConfig,
+     Gemma3ForCausalLM,
+     PretrainedConfig,
+     PreTrainedModel,
+     SiglipVisionModel,
+ )
+ from transformers.models.auto import CONFIG_MAPPING
+
+
+ class SiglipGemmaValueConfig(PretrainedConfig):
+     """Configuration class for SiglipGemmaValueModel.
+
+     Args:
+         siglip_config: Configuration for the SigLIP vision model.
+         gemma_config: Configuration for the Gemma language model.
+         num_value_bins: Number of bins for value discretization. Defaults to 201.
+         **kwargs: Additional keyword arguments passed to `PretrainedConfig`.
+     """
+
+     model_type = "SiglipGemmaValueModel"
+     sub_configs = {"siglip_config": AutoConfig, "gemma_config": AutoConfig}
+
+     def __init__(
+         self,
+         siglip_config: dict | None = None,
+         gemma_config: dict | None = None,
+         num_value_bins: int = 201,
+         **kwargs,
+     ):
+         self.num_value_bins = num_value_bins
+
+         if siglip_config is None:
+             # Default SigLIP config, similar to the PaliGemma vision config
+             self.siglip_config = CONFIG_MAPPING["siglip_vision_model"](
+                 hidden_size=1152,
+                 intermediate_size=4304,
+                 model_type="siglip_vision_model",
+                 num_attention_heads=16,
+                 num_hidden_layers=27,
+                 num_image_tokens=256,
+                 patch_size=14,
+                 projector_hidden_act="gelu_fast",
+                 torch_dtype="float32",
+                 vision_use_head=False,
+             )
+         elif isinstance(siglip_config, dict):
+             if "model_type" not in siglip_config:
+                 siglip_config["model_type"] = "siglip_vision_model"
+
+             cfg_cls = CONFIG_MAPPING[siglip_config["model_type"]]
+             self.siglip_config = cfg_cls(**siglip_config)
+         else:
+             self.siglip_config = siglip_config
+
+         if gemma_config is None:
+             # Default config for Gemma 3 270M
+             # Based on typical scaling: smaller than the 1B model
+             self.gemma_config = CONFIG_MAPPING["gemma"](
+                 attention_bias=False,
+                 attention_dropout=0.0,
+                 bos_token_id=2,
+                 eos_token_id=1,
+                 head_dim=128,
+                 hidden_act="gelu_pytorch_tanh",
+                 hidden_activation="gelu_pytorch_tanh",
+                 hidden_size=640,
+                 initializer_range=0.02,
+                 intermediate_size=2048,
+                 max_position_embeddings=8192,
+                 model_type="gemma",
+                 num_attention_heads=8,
+                 num_hidden_layers=18,
+                 num_key_value_heads=1,
+                 pad_token_id=0,
+                 rms_norm_eps=1e-06,
+                 rope_theta=10000.0,
+                 torch_dtype="float32",
+                 transformers_version="4.48.1",
+                 use_cache=True,
+                 vocab_size=257152,
+             )
+         elif isinstance(gemma_config, dict):
+             if "model_type" not in gemma_config:
+                 gemma_config["model_type"] = "gemma"
+
+             cfg_cls = CONFIG_MAPPING[gemma_config["model_type"]]
+             self.gemma_config = cfg_cls(**gemma_config)
+         else:
+             self.gemma_config = gemma_config
+
+         super().__init__(**kwargs)
+
+     def __post_init__(self):
+         super().__post_init__()
+
+         if self.attention_implementation not in ["eager", "fa2"]:
+             raise ValueError(
+                 f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager' or 'fa2'."
+             )
+
+
+ class SiglipGemmaValueModel(PreTrainedModel):
+     """SigLIP + Gemma Model for Value Function Estimation.
+
+     This model combines a SigLIP vision encoder and a Gemma language model to estimate
+     state values. It projects the final hidden state to a set of discretized value bins.
+
+     Args:
+         config: Configuration object of type `SiglipGemmaValueConfig`.
+     """
+
+     config_class = SiglipGemmaValueConfig
+
+     def __init__(self, config: SiglipGemmaValueConfig):
+         """Initializes the SiglipGemmaValueModel.
+
+         Args:
+             config: Configuration object of type `SiglipGemmaValueConfig`.
+         """
+         super().__init__(config=config)
+
+         self.vision_encoder = SiglipVisionModel.from_pretrained("google/siglip2-so400m-patch14-224")
+
+         # Initialize language model (Gemma 3 270M)
+         self.gemma = Gemma3ForCausalLM.from_pretrained("google/gemma-3-270m")
+         self.gemma = self.gemma.model  # we do not want the LM head
+
+         # Value head: projects final hidden state to discretized value bins
+         self.value_head = nn.Linear(640, config.num_value_bins)
+
+     def embed_image(self, image: torch.Tensor) -> torch.Tensor:
+         """Embeds images using the SigLIP vision encoder.
+
+         Args:
+             image: Tensor containing image pixel values.
+
+         Returns:
+             torch.Tensor: The embedded image features.
+         """
+         # Handle different transformers versions
+         if hasattr(self.vision_encoder, "get_image_features"):
+             return self.vision_encoder.get_image_features(image)
+         else:
+             outputs = self.vision_encoder(pixel_values=image)
+             return outputs.last_hidden_state
+
+     def embed_language_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+         """Embeds language tokens using the Gemma embedding layer.
+
+         Args:
+             tokens: Tensor containing language token IDs.
+
+         Returns:
+             torch.Tensor: The embedded language tokens.
+         """
+         return self.gemma.embed_tokens(tokens)
+
+     def forward(
+         self,
+         inputs_embeds: torch.FloatTensor,
+         attention_mask: torch.Tensor,
+         position_ids: torch.LongTensor,
+     ) -> torch.Tensor:
+         """Forward pass that processes vision and language inputs and outputs a value.
+
+         Args:
+             inputs_embeds: Tensor of shape [batch_size, sequence_length, embedding_dim]
+                 containing the combined embeddings of images and text.
+             attention_mask: Attention mask for the sequence.
+             position_ids: Position IDs for RoPE.
+
+         Returns:
+             torch.Tensor: Logits for discretized values of shape [batch_size, num_value_bins].
+         """
+
+         attention_mask = rearrange(attention_mask, "b n1 n2 -> b 1 n1 n2")  # support multi-head attention
+         # HACK: use full attention for sliding attention as well, since our context length
+         # is almost the same size as the sliding window
+         mask_mapping = {
+             "full_attention": attention_mask,
+             "sliding_attention": attention_mask,
+         }
+         outputs = self.gemma(
+             inputs_embeds=inputs_embeds,
+             position_ids=position_ids,
+             attention_mask=mask_mapping,
+         )
+         hidden_states = outputs.last_hidden_state
+
+         # Extract the last token's hidden state for value prediction
+         # (the last token should be the last language token)
+         final_hidden = hidden_states[:, -1]
+
+         # Project to logits for discretized values
+         logits = self.value_head(final_hidden)
+
+         return logits
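
For orientation, here is a minimal sketch of driving the model above end to end. Only SiglipGemmaValueConfig, SiglipGemmaValueModel, and their methods come from the diff; the tensor shapes, the decision to feed language embeddings alone (the real pipeline in modeling_value.py fuses the projected vision features first), and the softmax-expectation readout at the end are illustrative assumptions.

import torch

from opentau.policies.value.siglip_gemma import SiglipGemmaValueConfig, SiglipGemmaValueModel

config = SiglipGemmaValueConfig(num_value_bins=201)
model = SiglipGemmaValueModel(config).eval()  # downloads the SigLIP and Gemma checkpoints from the Hub

image = torch.randn(1, 3, 224, 224)  # one 224x224 RGB frame
tokens = torch.randint(0, config.gemma_config.vocab_size, (1, 16))  # 16 language token IDs

with torch.no_grad():
    img_feats = model.embed_image(image)  # [1, 256, 1152] SigLIP patch features
    txt_embeds = model.embed_language_tokens(tokens)  # [1, 16, 640] Gemma embeddings

    # Assumption: the surrounding policy code projects the 1152-dim vision features down to
    # Gemma's 640-dim width and prepends them; this sketch feeds the language embeddings only.
    seq = txt_embeds
    n = seq.shape[1]
    attention_mask = torch.ones(1, n, n, dtype=torch.bool)  # full [b, n, n] mask; forward adds the head dim
    position_ids = torch.arange(n).unsqueeze(0)

    logits = model(inputs_embeds=seq, attention_mask=attention_mask, position_ids=position_ids)  # [1, 201]

# Hypothetical readout: treat the bins as a distribution over values in [0, 1] and take the
# softmax-weighted expectation (the actual bin-to-value mapping lives in reward.py).
bins = torch.linspace(0.0, 1.0, config.num_value_bins)
value = (logits.float().softmax(-1) * bins).sum(-1)

The 201 bins mirror the num_value_bins default, so the value head can be decoded either as an argmax bin or as the expectation shown here.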
opentau/scripts/actions_mse_loss.py
@@ -0,0 +1,89 @@
+ #!/usr/bin/env python
+
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from dataclasses import asdict
+ from pprint import pformat
+
+ import numpy as np
+ import torch
+ from sklearn.metrics import r2_score
+ from torch.utils.data import DataLoader
+
+ from opentau.configs import parser
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.datasets.factory import make_dataset_mixture
+ from opentau.policies.factory import get_policy_class
+ from opentau.utils.random_utils import set_seed
+ from opentau.utils.utils import (
+     attempt_torch_compile,
+     auto_torch_device,
+     init_logging,
+ )
+
+
+ @parser.wrap()
+ def inference_main(cfg: TrainPipelineConfig):
+     logging.info(pformat(asdict(cfg)))
+     # Build the LeRobot dataset mixture and a dataloader per dataset.
+     datasets = make_dataset_mixture(cfg)
+
+     # Load the trained or fine-tuned model. Set the batch size to 1 in the config.
+
+     device = auto_torch_device()
+     if cfg.seed is not None:
+         set_seed(cfg.seed)
+
+     logging.info("Creating policy")
+     policy_class = get_policy_class(cfg.policy.type)
+     policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy)
+     policy = policy.to(device=device, dtype=torch.bfloat16)
+     policy.eval()
+     policy_sample_actions = attempt_torch_compile(policy.sample_actions, device_hint=device)
+
+     # Always reset the policy before an episode to clear out the action cache.
+     policy.reset()
+
+     for dataset in datasets.datasets:
+         robot_dof = dataset.meta.info["features"]["actions"]["shape"][0]
+         assert cfg.max_action_dim >= robot_dof
+         print(f"The batch size is {cfg.batch_size}")
+         dataloader = DataLoader(dataset, batch_size=cfg.batch_size)
+
+         pred = []
+         truth = []
+         with torch.inference_mode():
+             for batch in dataloader:
+                 for key, value in batch.items():
+                     if isinstance(value, torch.Tensor):
+                         batch[key] = batch[key].to(device)
+                 action = policy_sample_actions(batch)
+                 predicted_action = action.to("cpu", torch.float32).numpy()
+                 pred.append(predicted_action[0, :, :robot_dof].squeeze(0))
+                 truth.append(batch["actions"][:, 0, :].squeeze(0)[:robot_dof].to(torch.float32).cpu().numpy())
+
+         pred = np.stack(pred, axis=0)
+         truth = np.stack(truth, axis=0)
+
+         print(f"the mean squared error loss per dimension is {np.mean((pred - truth) ** 2, axis=0)}")
+
+         print(f"the r2 score per dimension is {r2_score(truth, pred, multioutput='raw_values')}")
+     logging.info("End of inference")
+
+
+ if __name__ == "__main__":
+     init_logging()
+     inference_main()
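
As a sanity check on the metrics this script reports, the snippet below reproduces the per-dimension MSE and R2 computation on toy arrays; the 100x7 shapes stand in for 100 sampled timesteps of a 7-DoF action and are purely illustrative.

import numpy as np
from sklearn.metrics import r2_score

rng = np.random.default_rng(0)
truth = rng.normal(size=(100, 7))  # ground-truth actions, one row per sample
pred = truth + 0.1 * rng.normal(size=(100, 7))  # predictions with small additive noise

mse_per_dim = np.mean((pred - truth) ** 2, axis=0)  # shape (7,), roughly 0.01 everywhere
r2_per_dim = r2_score(truth, pred, multioutput="raw_values")  # shape (7,), close to 1.0
print(mse_per_dim)
print(r2_per_dim)

Note the argument order: r2_score takes the ground truth first and the predictions second.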
opentau/scripts/bin_to_safetensors.py
@@ -0,0 +1,116 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ import copy  # Import copy module for deepcopy
+ import os
+ import sys
+
+ import torch
+ from safetensors.torch import save_file  # Removed load_file as it's not used here
+
+
+ def convert_bin_to_safetensors(input_path: str, output_path: str, map_location: str = "cpu"):
+     """
+     Converts a PyTorch .bin checkpoint (state_dict) to .safetensors format.
+
+     This script attempts to handle cases where tensors might share memory by
+     creating a deep copy of the state_dict before saving to .safetensors.
+     This ensures that each tensor has its own memory, which is a requirement
+     for `safetensors.torch.save_file` when shared tensors are detected.
+
+     Args:
+         input_path (str): Path to the input .bin file.
+         output_path (str): Path where the .safetensors file will be saved.
+         map_location (str): Device to map the tensors to when loading the .bin file.
+             Defaults to 'cpu' to avoid GPU memory issues during conversion.
+             Can be 'cuda', 'cuda:0', etc.
+     """
+     if not os.path.exists(input_path):
+         print(f"Error: Input file not found at '{input_path}'", file=sys.stderr)
+         sys.exit(1)
+
+     if not output_path.lower().endswith(".safetensors"):
+         print(
+             f"Warning: Output path '{output_path}' does not end with '.safetensors'. Appending it.",
+             file=sys.stderr,
+         )
+         output_path += ".safetensors"
+
+     print(f"Attempting to load state_dict from '{input_path}'...")
+     try:
+         # Load the state_dict from the .bin file
+         state_dict = torch.load(input_path, map_location=torch.device(map_location))  # nosec B614
+         print("State_dict loaded successfully.")
+
+         # Handle shared memory tensors for safetensors compatibility.
+         # Create a deep copy of the state_dict to ensure all tensors have unique memory locations.
+         # This resolves the "Some tensors share memory" error from safetensors.
+         # Note: This might increase the file size if many tensors were originally shared.
+         unique_state_dict = {}
+         for key, value in state_dict.items():
+             if isinstance(value, torch.Tensor):
+                 unique_state_dict[key] = value.clone().detach()
+             else:
+                 unique_state_dict[key] = copy.deepcopy(value)  # Handle non-tensor items (e.g., lists, dicts)
+
+         # Save the state_dict to .safetensors format
+         print(f"Saving state_dict to '{output_path}' in .safetensors format...")
+         save_file(unique_state_dict, output_path)
+         print(f"Conversion successful! Output saved to '{output_path}'")
+
+     except Exception as e:
+         print(f"An error occurred during conversion: {e}", file=sys.stderr)
+         # Provide more specific guidance if it's the known safetensors shared memory error
+         if "Some tensors share memory" in str(e):
+             print(
+                 "\nThis error typically occurs when the PyTorch state_dict contains tensors that share the same underlying memory (e.g., `lm_head.weight` and `embed_tokens.weight`)."
+             )
+             print("The script attempted to resolve this by deep-copying the state_dict before saving.")
+             print(
+                 "If the issue persists, ensure your `safetensors` library is up-to-date (`pip install --upgrade safetensors`)."
+             )
+             print(
+                 "For models with complex shared weight patterns, manually loading the model architecture and using `safetensors.torch.save_model(model, output_path)` might be necessary, as it handles shared weights more robustly."
+             )
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Convert a PyTorch .bin checkpoint (state_dict) to .safetensors format."
+     )
+     parser.add_argument("input_file", type=str, help="Path to the input .bin model weights file.")
+     parser.add_argument(
+         "--output_file",
+         type=str,
+         help="Path to save the output .safetensors file. If not provided, "
+         "it will be inferred from the input filename (e.g., model.bin -> model.safetensors).",
+     )
+     parser.add_argument(
+         "--map_location",
+         type=str,
+         default="cpu",
+         help="Device to map the tensors to when loading the .bin file. "
+         "Defaults to 'cpu'. Can be 'cuda', 'cuda:0', etc.",
+     )
+
+     args = parser.parse_args()
+
+     # If output_file is not provided, infer it
+     if args.output_file is None:
+         base_name = os.path.splitext(args.input_file)[0]
+         args.output_file = base_name + ".safetensors"
+
+     convert_bin_to_safetensors(args.input_file, args.output_file, args.map_location)
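
The deep-copy step above exists because safetensors refuses to serialize tensors that alias the same storage, as tied lm_head/embedding weights typically do. Here is a minimal reproduction of the failure and of the same fix the script applies; the key names are illustrative.

import torch
from safetensors.torch import save_file

w = torch.randn(4, 4)
state = {"lm_head.weight": w, "embed_tokens.weight": w}  # two keys, one shared storage

try:
    save_file(state, "tied.safetensors")  # raises because the tensors share memory
except RuntimeError as e:
    print(e)

untied = {k: v.clone().detach() for k, v in state.items()}  # same fix as in the script
save_file(untied, "untied.safetensors")  # succeeds: each tensor now owns its memory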
opentau/scripts/compute_max_token_length.py
@@ -0,0 +1,111 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import math
+ from collections import Counter
+ from dataclasses import dataclass
+ from functools import partial
+ from itertools import accumulate
+ from multiprocessing import Pool, cpu_count
+ from pathlib import Path
+
+ from tqdm import tqdm
+ from transformers import AutoTokenizer, PreTrainedTokenizer
+
+ from opentau.configs import parser
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.datasets.factory import make_dataset_mixture
+ from opentau.datasets.lerobot_dataset import BaseDataset
+ from opentau.policies.factory import get_policy_class
+ from opentau.policies.pi0.modeling_pi0 import PI0Policy
+ from opentau.policies.pi05.modeling_pi05 import PI05Policy
+
+
+ @dataclass
+ class Args:
+     target_cfg: str  # path to the training configuration file
+     keys: tuple[str, ...] = (
+         "response",
+         "prompt",
+     )  # keys to compute max token length for, e.g. ["response", "prompt"]
+     num_workers: int | None = None
+     chunk_size: int = 1000
+     output_path: str | None = None
+
+
+ def get_tokenizer(cfg: TrainPipelineConfig) -> callable:
+     r"""Returns a tokenizer based on the policy type in the configuration."""
+     policy_class = get_policy_class(cfg.policy.type)
+
+     # TODO: Add `elif` for other policy types if needed
+     if issubclass(policy_class, PI0Policy) or issubclass(policy_class, PI05Policy):
+         return AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
+
+     raise ValueError(f"Unsupported policy type: {cfg.policy.type}")
+
+
+ def chunked(dataset: BaseDataset, key: str, chunk_size: int):
+     n = len(dataset)
+     for start in range(0, n, chunk_size):
+         end = min(n, start + chunk_size)
+         yield [dataset[i][key] for i in range(start, end)]
+
+
+ def worker_fn(chunk, tokenizer: PreTrainedTokenizer):
+     return Counter(len(tokenizer(s)["input_ids"]) for s in chunk)
+
+
+ def to_percentile(counter: Counter) -> dict[int, float]:
+     r"""Convert a counter to a dictionary with token lengths as keys and their percentile as values."""
+     total = counter.total()
+     sorted_keys = sorted(counter.keys())
+     values = accumulate(counter[k] for k in sorted_keys)
+     return {k: v / total for k, v in reversed(list(zip(sorted_keys, values, strict=False)))}
+
+
+ @parser.wrap()
+ def main(args: Args):
+     cfg = TrainPipelineConfig.from_pretrained(args.target_cfg)
+     datasets = make_dataset_mixture(cfg).datasets
+     tokenizer = get_tokenizer(cfg)
+     worker = partial(worker_fn, tokenizer=tokenizer)
+
+     output = {}
+     for key in args.keys:
+         counter = Counter()
+         for ds in datasets:
+             tasks = tqdm(
+                 chunked(ds, key, args.chunk_size),
+                 total=math.ceil(len(ds) / args.chunk_size),
+                 desc=f"Processing {key} in {ds._get_feature_mapping_key()}",
+             )
+             # TODO: multiprocessing doesn't seem to speed things up. debug why.
+             with Pool(args.num_workers or cpu_count()) as pool:
+                 parts = pool.imap_unordered(worker, tasks)
+                 counter = sum(parts, start=counter)
+
+         output[key] = to_percentile(counter)
+
+     if args.output_path:
+         path = Path(args.output_path)
+         path.parent.mkdir(parents=True, exist_ok=True)
+         with open(path, "w") as f:
+             json.dump(output, f, indent=2)
+     else:
+         print(json.dumps(output, indent=2))
+
+
+ if __name__ == "__main__":
+     main()
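
For intuition about the output format: to_percentile turns the histogram of token lengths into a cumulative distribution, mapping each length to the fraction of samples at or below it. The same logic on a toy histogram:

from collections import Counter
from itertools import accumulate

counter = Counter({8: 2, 10: 5, 16: 3})  # token length -> number of samples
total = counter.total()  # 10
sorted_keys = sorted(counter)  # [8, 10, 16]
cumulative = accumulate(counter[k] for k in sorted_keys)  # 2, 7, 10
print({k: v / total for k, v in zip(sorted_keys, cumulative)})
# {8: 0.2, 10: 0.7, 16: 1.0} -- 70% of samples fit in 10 tokens; 16 covers them all

Reading the JSON the script emits the same way gives the smallest max-token-length budget that covers any target percentile.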
opentau/scripts/display_sys_info.py
@@ -0,0 +1,90 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Use this script to get a quick summary of your system config.
+ It should be able to run without any of OpenTau's dependencies or OpenTau itself installed.
+ """
+
+ import platform
+
+ HAS_HF_HUB = True
+ HAS_HF_DATASETS = True
+ HAS_NP = True
+ HAS_TORCH = True
+ HAS_OPENTAU = True
+
+ try:
+     import huggingface_hub
+ except ImportError:
+     HAS_HF_HUB = False
+
+ try:
+     import datasets
+ except ImportError:
+     HAS_HF_DATASETS = False
+
+ try:
+     import numpy as np
+ except ImportError:
+     HAS_NP = False
+
+ try:
+     import torch
+ except ImportError:
+     HAS_TORCH = False
+
+ try:
+     import opentau
+ except ImportError:
+     HAS_OPENTAU = False
+
+
+ opentau_version = opentau.__version__ if HAS_OPENTAU else "N/A"
+ hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
+ hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
+ np_version = np.__version__ if HAS_NP else "N/A"
+
+ torch_version = torch.__version__ if HAS_TORCH else "N/A"
+ torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
+ cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
+
+
+ def display_sys_info() -> dict:
+     """Run this to get basic system info to help for tracking issues & bugs."""
+     info = {
+         "`opentau` version": opentau_version,
+         "Platform": platform.platform(),
+         "Python version": platform.python_version(),
+         "Huggingface_hub version": hf_hub_version,
+         "Dataset version": hf_datasets_version,
+         "Numpy version": np_version,
+         "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
+         "Cuda version": cuda_version,
+         "Using GPU in script?": "<fill in>",
+         # "Using distributed or parallel set-up in script?": "<fill in>",
+     }
+     print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
+     print(format_dict(info))
+     return info
+
+
+ def format_dict(d: dict) -> str:
+     return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
+
+
+ if __name__ == "__main__":
+     display_sys_info()
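
format_dict renders each entry as a dash bullet, so the block users paste into an issue looks like the commented output below (the version strings here are placeholders):

from opentau.scripts.display_sys_info import format_dict

print(format_dict({"`opentau` version": "0.1.0", "Python version": "3.10.14"}))
# - `opentau` version: 0.1.0
# - Python version: 3.10.14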
opentau/scripts/download_libero_benchmarks.py
@@ -0,0 +1,54 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ from libero.libero import get_libero_path
+ from libero.libero.utils.download_utils import check_libero_dataset, libero_dataset_download
+
+ from opentau.configs import parser
+
+
+ @dataclass
+ class Args:
+     suite: str | None = None
+     download_dir: str = get_libero_path("datasets")
+
+     def __post_init__(self):
+         if self.suite not in [None, "object", "spatial", "goal", "10", "90"]:
+             raise ValueError(
+                 f"Invalid suite: {self.suite}. Available suites are: 'object', 'spatial', 'goal', '10', or '90'."
+             )
+
+
+ @parser.wrap()
+ def main(args: Args):
+     # Resolve and create the user-specified download directory for the datasets.
+     download_dir = Path(args.download_dir)
+     download_dir.mkdir(parents=True, exist_ok=True)
+     download_dir = str(download_dir.resolve())
+     print("Datasets will be downloaded to:", download_dir)
+
+     datasets = "all" if args.suite is None else f"libero_{args.suite}"
+     print("Datasets to download:", datasets)
+
+     libero_dataset_download(datasets=datasets, download_dir=download_dir, use_huggingface=True)
+     time.sleep(1)
+     check_libero_dataset(download_dir=download_dir)
+
+
+ if __name__ == "__main__":
+     main()
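
Because suite names are validated in __post_init__, the Args dataclass can be exercised on its own; a small sketch follows (note that merely importing the module requires libero, since the download_dir default calls get_libero_path at class-definition time):

from opentau.scripts.download_libero_benchmarks import Args

args = Args(suite="spatial")
print(f"libero_{args.suite}")  # -> libero_spatial, the name handed to libero_dataset_download

try:
    Args(suite="libero_10")  # suite names are given bare: "10", not "libero_10"
except ValueError as e:
    print(e)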