opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/policies/pretrained.py
@@ -0,0 +1,315 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for pre-trained policies in OpenTau.

This module defines the abstract base class `PreTrainedPolicy` which handles
loading, saving, and basic interface requirements for all policy implementations
in the OpenTau library. It integrates with Hugging Face Hub for model sharing
and safetensors for efficient serialization.
"""

import abc
import logging
import os
from pathlib import Path
from typing import Type, TypeVar

import packaging.version
import safetensors
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn

from opentau.configs.policies import PreTrainedConfig
from opentau.policies.utils import log_model_loading_keys
from opentau.utils.hub import HubMixin

T = TypeVar("T", bound="PreTrainedPolicy")

DEFAULT_POLICY_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---

This policy has been pushed to the Hub using [OpenTau](https://github.com/TensorAuto/OpenTau):
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""


class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
    """Base class for all policy models in OpenTau.

    This class extends `nn.Module` and `HubMixin` to provide common functionality
    for policy models, including configuration management, model loading/saving,
    and abstract methods that all policies must implement.

    Attributes:
        config: The configuration instance for this policy.
    """

    config_class: None
    """The configuration class associated with this policy. Must be defined in subclasses."""

    name: None
    """The name of the policy. Must be defined in subclasses."""

    def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
        """Initializes the PreTrainedPolicy.

        Args:
            config: The configuration object for the policy.
            *inputs: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            ValueError: If `config` is not an instance of `PreTrainedConfig`.
        """
        super().__init__()
        if not isinstance(config, PreTrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PreTrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.config = config

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if not getattr(cls, "config_class", None):
            raise TypeError(f"Class {cls.__name__} must define 'config_class'")
        if not getattr(cls, "name", None):
            raise TypeError(f"Class {cls.__name__} must define 'name'")

    def _save_pretrained(self, save_directory: Path) -> None:
        """Saves the policy and its configuration to a directory.

        Args:
            save_directory: The directory to save the policy to.
        """
        self.config._save_pretrained(save_directory)
        model_to_save = self.module if hasattr(self, "module") else self
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = False,
        **kwargs,
    ) -> T:
        """Loads a pretrained policy from a local path or the Hugging Face Hub.

        The policy is set in evaluation mode by default using `policy.eval()`
        (dropout modules are deactivated). To train it, you should first set it
        back in training mode with `policy.train()`.

        Args:
            pretrained_name_or_path: The name or path of the pretrained model.
            config: Optional configuration object. If None, it will be loaded from the
                pretrained model.
            force_download: Whether to force download the model weights.
            resume_download: Whether to resume an interrupted download.
            proxies: Proxy configuration for downloading.
            token: Hugging Face token for authentication.
            cache_dir: Directory to cache downloaded files.
            local_files_only: Whether to only look for local files.
            revision: The specific model version to use (branch, tag, or commit hash).
            strict: Whether to strictly enforce matching keys in state_dict.
            **kwargs: Additional keyword arguments passed to the constructor.

        Returns:
            T: An instance of the loaded policy.

        Raises:
            FileNotFoundError: If the model file is not found.
        """
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        model_id = str(pretrained_name_or_path)
        instance = cls(config, **kwargs)
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
        else:
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                ) from e

        policy.eval()
        return policy

    def _tile_linear_input_weight(self, state_dict_to_load: dict):
        """Modifies the `state_dict_to_load` in-place by tiling linear layer input weights.

        This ensures compatibility with the model architecture when weight dimensions don't match exactly,
        typically used for expanding input layers.

        Args:
            state_dict_to_load: The state dictionary to modify.
        """
        for name, submodule in self.named_modules():
            if not isinstance(submodule, torch.nn.Linear):
                continue
            weight_name = f"{name}.weight"
            if weight_name not in state_dict_to_load:
                continue
            weight = state_dict_to_load[weight_name]
            assert len(weight.shape) == 2, f"Shape of {weight_name} must be 2D, got {weight.shape}"
            out_dim, in_dim = weight.shape
            assert submodule.out_features == out_dim, (
                f"Output of {name} = {submodule.out_features} does not match loaded weight output dim {out_dim}"
            )
            if submodule.in_features == in_dim:
                continue

            logging.warning(f"Tiling {weight_name} from shape {weight.shape} to {submodule.weight.shape}")
            repeat, remainder = divmod(submodule.in_features, in_dim)
            weight = torch.cat([weight] * repeat + [weight[:, :remainder]], dim=1)
            state_dict_to_load[weight_name] = weight

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Loads model weights from a safetensors file.

        Args:
            model: The model instance to load weights into.
            model_file: Path to the safetensors file.
            map_location: Device to map the weights to.
            strict: Whether to enforce strict key matching.

        Returns:
            T: The model with loaded weights.
        """
        # Create base kwargs
        kwargs = {"strict": strict}

        # Add device parameter for newer versions that support it
        if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
            kwargs["device"] = map_location

        # Load the model with appropriate kwargs
        missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
        log_model_loading_keys(missing_keys, unexpected_keys)

        # For older versions, manually move to device if needed
        if "device" not in kwargs and map_location != "cpu":
            logging.warning(
                "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                " This means that the model is loaded on 'cpu' first and then copied to the device."
                " This leads to a slower loading time."
                " Please update safetensors to version 0.4.3 or above for improved performance."
            )
            model.to(map_location)
        return model

    # def generate_model_card(self, *args, **kwargs) -> ModelCard:
    #     card = ModelCard.from_template(
    #         card_data=self._hub_mixin_info.model_card_data,
    #         template_str=self._hub_mixin_info.model_card_template,
    #         repo_url=self._hub_mixin_info.repo_url,
    #         docs_url=self._hub_mixin_info.docs_url,
    #         **kwargs,
    #     )
    #     return card

    @abc.abstractmethod
    def get_optim_params(self) -> dict:
        """Returns the policy-specific parameters dict to be passed on to the optimizer.

        Returns:
            dict: A dictionary of parameters to optimize.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def reset(self):
        """Resets the policy state.

        This method should be called whenever the environment is reset.
        It handles tasks like clearing caches or resetting internal states for stateful policies.
        """
        raise NotImplementedError

    # TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
    @abc.abstractmethod
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
        """Performs a forward pass of the policy.

        Args:
            batch: A dictionary of input tensors.

        Returns:
            tuple[Tensor, dict | None]: A tuple containing:
                - The loss tensor.
                - An optional dictionary of metrics or auxiliary outputs.
                  Apart from the loss, items should be logging-friendly native Python types.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Selects an action based on the input batch.

        This method handles action selection during inference, including
        caching for stateful policies (e.g. RNNs, Transformers).

        Args:
            batch: A dictionary of observation tensors.

        Returns:
            Tensor: The selected action(s).
        """
        raise NotImplementedError
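Editor's note: the `__init_subclass__` hook above means any concrete policy must set both `config_class` and `name` and implement the four abstract methods before it can be instantiated. A minimal sketch of a conforming subclass follows; it is illustrative only and not part of the package diff, and `DummyPolicy` with its single linear head is hypothetical:

import torch
from torch import Tensor, nn

from opentau.configs.policies import PreTrainedConfig
from opentau.policies.pretrained import PreTrainedPolicy


class DummyPolicy(PreTrainedPolicy):
    """Hypothetical policy showing the interface PreTrainedPolicy enforces."""

    # Both attributes are checked by __init_subclass__ at class-definition time.
    config_class = PreTrainedConfig  # in practice, a concrete PreTrainedConfig subclass
    name = "dummy"

    def __init__(self, config: PreTrainedConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.net = nn.Linear(8, 4)

    def get_optim_params(self) -> dict:
        # Hand every parameter to the optimizer; real policies may form groups.
        return {"params": self.parameters()}

    def reset(self):
        # Stateless policy: nothing to clear between episodes.
        pass

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
        loss = self.net(batch["observation.state"]).mean()
        return loss, {"loss": loss.item()}

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        with torch.no_grad():
            return self.net(batch["observation.state"])


# Loading a published checkpoint then follows the base-class path, which
# downloads the single safetensors file and calls policy.eval():
# policy = DummyPolicy.from_pretrained("org/repo")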
opentau/policies/utils.py
@@ -0,0 +1,123 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for policy implementations in OpenTau.

This module provides helper functions for managing data queues, inspecting model
properties (device, dtype), determining output shapes, and logging model loading
information.
"""

import logging
from collections import deque

import torch
from torch import nn


def populate_queues(
    queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
) -> dict[str, deque]:
    """Populates queues with batch data.

    If a queue is not full (e.g. at the start of an episode), it is filled by repeating
    the first observation. Otherwise, the latest observation is appended.

    Args:
        queues: A dictionary of deques to be populated.
        batch: A dictionary containing the data to add to the queues.
        exclude_keys: A list of keys to exclude from population. Defaults to None.

    Returns:
        dict[str, deque]: The updated dictionary of queues.
    """
    if exclude_keys is None:
        exclude_keys = []
    for key in batch:
        # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
        # queues have the keys they want).
        if key not in queues or key in exclude_keys:
            continue
        if len(queues[key]) != queues[key].maxlen:
            # initialize by copying the first observation several times until the queue is full
            while len(queues[key]) != queues[key].maxlen:
                queues[key].append(batch[key])
        else:
            # add latest observation to the queue
            queues[key].append(batch[key])
    return queues


def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Note:
        Assumes that all parameters have the same device.

    Args:
        module: The PyTorch module to inspect.

    Returns:
        torch.device: The device of the module's parameters.
    """
    return next(iter(module.parameters())).device


def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
    """Get a module's parameter dtype by checking one of its parameters.

    Note:
        Assumes that all parameters have the same dtype.

    Args:
        module: The PyTorch module to inspect.

    Returns:
        torch.dtype: The data type of the module's parameters.
    """
    return next(iter(module.parameters())).dtype


def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
    """Calculates the output shape of a PyTorch module given an input shape.

    Args:
        module: A PyTorch module.
        input_shape: A tuple representing the input shape, e.g., (batch_size, channels, height, width).

    Returns:
        tuple: The output shape of the module.
    """
    dummy_input = torch.zeros(size=input_shape)
    with torch.inference_mode():
        output = module(dummy_input)
    return tuple(output.shape)


def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None:
    """Log missing and unexpected keys when loading a model.

    Args:
        missing_keys: Keys that were expected but not found.
        unexpected_keys: Keys that were found but not expected.
    """
    if missing_keys:
        # DO NOT UPDATE THIS MESSAGE WITHOUT UPDATING THE REGEX IN .gitlab/scripts/check_pi0_state_keys.py
        logging.warning(f"Missing key(s) when loading model: {missing_keys}")
    if unexpected_keys:
        # DO NOT UPDATE THIS MESSAGE WITHOUT UPDATING THE REGEX IN .gitlab/scripts/check_pi0_state_keys.py
        logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
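Editor's note: to make the fill-then-append contract of `populate_queues` concrete, here is a small self-contained example (illustrative only, not part of the package diff; the tensor shapes and key names are arbitrary):

from collections import deque

import torch

from opentau.policies.utils import populate_queues

# One queue holding the two most recent observations.
queues = {"observation.state": deque(maxlen=2)}

first = {"observation.state": torch.zeros(1, 4)}
second = {"observation.state": torch.ones(1, 4)}

# First call: the queue is not full, so the first observation is repeated
# until maxlen is reached -> [zeros, zeros].
populate_queues(queues, first)

# Second call: the queue is full, so the newest observation is appended and
# the oldest is evicted -> [zeros, ones].
populate_queues(queues, second)

# Keys absent from `queues` (or listed in exclude_keys) are ignored entirely.
populate_queues(queues, {"observation.images.top": torch.zeros(1, 3, 8, 8)})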
opentau/policies/value/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Value Policy Module.

This module implements value-based policies for robot control. It includes
configurations, models, and reward functions for value estimation.
"""
opentau/policies/value/configuration_value.py
@@ -0,0 +1,170 @@
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration module for the Value policy.

This module defines the `ValueConfig` class, which handles the configuration parameters
for the Value policy. It includes settings for the model architecture,
optimization, scheduling, and data processing.
"""

from dataclasses import dataclass, field

from opentau.configs.policies import PreTrainedConfig
from opentau.configs.reward import RewardConfig
from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
from opentau.optim.optimizers import AdamWConfig
from opentau.optim.schedulers import (
    CosineDecayWithWarmupSchedulerConfig,
    LRSchedulerConfig,
)


@PreTrainedConfig.register_subclass("value")
@dataclass
class ValueConfig(PreTrainedConfig):
    """Configuration class for the Value policy.

    Args:
        n_obs_steps: Number of observation steps to be used.
        chunk_size: The chunk size for the policy.
        normalization_mapping: Mapping of feature types to normalization modes.
        max_state_dim: Maximum dimension for state vectors.
        resize_imgs_with_padding: Tuple indicating the size to resize images with padding.
        empty_cameras: Number of empty cameras to add.
        tokenizer_max_length: Maximum length for the tokenizer.
        reward_config: Configuration for the reward.
        optimizer_lr: Learning rate for the optimizer.
        optimizer_betas: Betas for the optimizer.
        optimizer_eps: Epsilon for the optimizer.
        optimizer_weight_decay: Weight decay for the optimizer.
        scheduler_warmup_steps: Number of warmup steps for the scheduler.
        scheduler_decay_steps: Number of decay steps for the scheduler.
        scheduler_decay_lr: Decay learning rate for the scheduler.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 50

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "VALUE": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state vectors will be padded
    max_state_dim: int = 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (224, 224)

    # Add empty images.
    empty_cameras: int = 0

    # Tokenizer
    tokenizer_max_length: int = 48

    # Reward config
    reward_config: RewardConfig = field(default_factory=RewardConfig)

    # Training presets
    optimizer_lr: float = 2.5e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        super().__post_init__()

        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

        max_episode_length = self.reward_config.reward_normalizer + self.reward_config.C_neg
        assert max_episode_length < abs(self.reward_config.C_neg), (
            "Max episode length should be less than the absolute value of C_neg for proper separation of failed and successful episodes"
        )

    def validate_features(self) -> None:
        """Validates features and adds empty cameras if specified."""
        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        """Returns the optimizer preset configuration.

        Returns:
            AdamWConfig: The optimizer configuration.
        """
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> LRSchedulerConfig:
        """Returns the scheduler preset configuration.

        Returns:
            CosineDecayWithWarmupSchedulerConfig: The scheduler configuration.
        """
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        """Returns the observation delta indices.

        Returns:
            None: Always returns None.
        """
        return None

    @property
    def action_delta_indices(self) -> list:
        """Returns the action delta indices.

        Returns:
            list: List of indices from 0 to chunk_size - 1.
        """
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """Returns the reward delta indices.

        Returns:
            None: Always returns None.
        """
        return None
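Editor's note: a sketch of how `ValueConfig` is typically exercised (illustrative only, not part of the package diff; it assumes the `PreTrainedConfig` base fields all have defaults and that the default `RewardConfig` satisfies the episode-length assertion in `__post_init__`):

from opentau.policies.value.configuration_value import ValueConfig

cfg = ValueConfig(chunk_size=50, empty_cameras=1)

# Optimizer and scheduler presets are derived directly from the dataclass fields.
optim_cfg = cfg.get_optimizer_preset()   # AdamWConfig(lr=2.5e-5, betas=(0.9, 0.95), ...)
sched_cfg = cfg.get_scheduler_preset()   # cosine decay with 1_000 warmup steps

# validate_features() appends a placeholder (3, 480, 640) visual feature
# per empty camera to the config's input features.
cfg.validate_features()
assert "observation.images.empty_camera_0" in cfg.input_features

# The action horizon covers indices 0 .. chunk_size - 1.
assert cfg.action_delta_indices == list(range(50))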