opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/policies/pretrained.py
@@ -0,0 +1,315 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Base class for pre-trained policies in OpenTau.
+
+ This module defines the abstract base class `PreTrainedPolicy`, which handles
+ loading, saving, and the basic interface requirements for all policy
+ implementations in the OpenTau library. It integrates with the Hugging Face Hub
+ for model sharing and with safetensors for efficient serialization.
+ """
+
+ import abc
+ import logging
+ import os
+ from pathlib import Path
+ from typing import ClassVar, Type, TypeVar
+
+ import packaging.version
+ import safetensors
+ import torch
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
+ from huggingface_hub.errors import HfHubHTTPError
+ from safetensors.torch import load_model as load_model_as_safetensor
+ from safetensors.torch import save_model as save_model_as_safetensor
+ from torch import Tensor, nn
+
+ from opentau.configs.policies import PreTrainedConfig
+ from opentau.policies.utils import log_model_loading_keys
+ from opentau.utils.hub import HubMixin
+
+ T = TypeVar("T", bound="PreTrainedPolicy")
+
+ DEFAULT_POLICY_CARD = """
+ ---
+ # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+ # Doc / guide: https://huggingface.co/docs/hub/model-cards
+ {{ card_data }}
+ ---
+
+ This policy has been pushed to the Hub using [OpenTau](https://github.com/TensorAuto/OpenTau):
+ - Docs: {{ docs_url | default("[More Information Needed]", true) }}
+ """
+
+
+ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
+     """Base class for all policy models in OpenTau.
+
+     This class extends `nn.Module` and `HubMixin` to provide common functionality
+     for policy models, including configuration management, model loading/saving,
+     and abstract methods that all policies must implement.
+
+     Attributes:
+         config: The configuration instance for this policy.
+     """
+
+     config_class: ClassVar[type[PreTrainedConfig]]
+     """The configuration class associated with this policy. Must be defined in subclasses."""
+
+     name: ClassVar[str]
+     """The name of the policy. Must be defined in subclasses."""
+
+     def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
+         """Initializes the PreTrainedPolicy.
+
+         Args:
+             config: The configuration object for the policy.
+             *inputs: Variable length argument list.
+             **kwargs: Arbitrary keyword arguments.
+
+         Raises:
+             ValueError: If `config` is not an instance of `PreTrainedConfig`.
+         """
+         super().__init__()
+         if not isinstance(config, PreTrainedConfig):
+             raise ValueError(
+                 f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
+                 "`PreTrainedConfig`. To create a model from a pretrained model use "
+                 f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
+             )
+         self.config = config
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         if not getattr(cls, "config_class", None):
+             raise TypeError(f"Class {cls.__name__} must define 'config_class'")
+         if not getattr(cls, "name", None):
+             raise TypeError(f"Class {cls.__name__} must define 'name'")
+
+     def _save_pretrained(self, save_directory: Path) -> None:
+         """Saves the policy and its configuration to a directory.
+
+         Args:
+             save_directory: The directory to save the policy to.
+         """
+         self.config._save_pretrained(save_directory)
+         model_to_save = self.module if hasattr(self, "module") else self
+         save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
+
+     @classmethod
+     def from_pretrained(
+         cls: Type[T],
+         pretrained_name_or_path: str | Path,
+         *,
+         config: PreTrainedConfig | None = None,
+         force_download: bool = False,
+         resume_download: bool | None = None,
+         proxies: dict | None = None,
+         token: str | bool | None = None,
+         cache_dir: str | Path | None = None,
+         local_files_only: bool = False,
+         revision: str | None = None,
+         strict: bool = False,
+         **kwargs,
+     ) -> T:
+         """Loads a pretrained policy from a local path or the Hugging Face Hub.
+
+         The policy is set in evaluation mode by default using `policy.eval()`
+         (dropout modules are deactivated). To train it, you should first set it
+         back in training mode with `policy.train()`.
+
+         Args:
+             pretrained_name_or_path: The name or path of the pretrained model.
+             config: Optional configuration object. If None, it will be loaded from the
+                 pretrained model.
+             force_download: Whether to force download the model weights.
+             resume_download: Whether to resume an interrupted download.
+             proxies: Proxy configuration for downloading.
+             token: Hugging Face token for authentication.
+             cache_dir: Directory to cache downloaded files.
+             local_files_only: Whether to only look for local files.
+             revision: The specific model version to use (branch, tag, or commit hash).
+             strict: Whether to strictly enforce matching keys in state_dict.
+             **kwargs: Additional keyword arguments passed to the constructor.
+
+         Returns:
+             T: An instance of the loaded policy.
+
+         Raises:
+             FileNotFoundError: If the model file is not found.
+         """
+         if config is None:
+             config = PreTrainedConfig.from_pretrained(
+                 pretrained_name_or_path=pretrained_name_or_path,
+                 force_download=force_download,
+                 resume_download=resume_download,
+                 proxies=proxies,
+                 token=token,
+                 cache_dir=cache_dir,
+                 local_files_only=local_files_only,
+                 revision=revision,
+                 **kwargs,
+             )
+         model_id = str(pretrained_name_or_path)
+         instance = cls(config, **kwargs)
+         if os.path.isdir(model_id):
+             print("Loading weights from local directory")
+             model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
+             policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
+         else:
+             try:
+                 model_file = hf_hub_download(
+                     repo_id=model_id,
+                     filename=SAFETENSORS_SINGLE_FILE,
+                     revision=revision,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     token=token,
+                     local_files_only=local_files_only,
+                 )
+                 policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
+             except HfHubHTTPError as e:
+                 raise FileNotFoundError(
+                     f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
+                 ) from e
+
+         policy.eval()
+         return policy
+
+     def _tile_linear_input_weight(self, state_dict_to_load: dict):
+         """Modifies the `state_dict_to_load` in-place by tiling linear layer input weights.
+
+         This ensures compatibility with the model architecture when weight dimensions don't match exactly,
+         typically used for expanding input layers.
+
+         Args:
+             state_dict_to_load: The state dictionary to modify.
+         """
+         for name, submodule in self.named_modules():
+             if not isinstance(submodule, torch.nn.Linear):
+                 continue
+             weight_name = f"{name}.weight"
+             if weight_name not in state_dict_to_load:
+                 continue
+             weight = state_dict_to_load[weight_name]
+             assert len(weight.shape) == 2, f"Shape of {weight_name} must be 2D, got {weight.shape}"
+             out_dim, in_dim = weight.shape
+             assert submodule.out_features == out_dim, (
+                 f"Output of {name} = {submodule.out_features} does not match loaded weight output dim {out_dim}"
+             )
+             if submodule.in_features == in_dim:
+                 continue
+
+             logging.warning(f"Tiling {weight_name} from shape {weight.shape} to {submodule.weight.shape}")
+             repeat, remainder = divmod(submodule.in_features, in_dim)
+             weight = torch.cat([weight] * repeat + [weight[:, :remainder]], dim=1)
+             state_dict_to_load[weight_name] = weight
+
+     @classmethod
+     def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
+         """Loads model weights from a safetensors file.
+
+         Args:
+             model: The model instance to load weights into.
+             model_file: Path to the safetensors file.
+             map_location: Device to map the weights to.
+             strict: Whether to enforce strict key matching.
+
+         Returns:
+             T: The model with loaded weights.
+         """
+         # Create base kwargs
+         kwargs = {"strict": strict}
+
+         # Add device parameter for newer versions that support it
+         if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
+             kwargs["device"] = map_location
+
+         # Load the model with appropriate kwargs
+         missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
+         log_model_loading_keys(missing_keys, unexpected_keys)
+
+         # For older versions, manually move to device if needed
+         if "device" not in kwargs and map_location != "cpu":
+             logging.warning(
+                 "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
+                 " This means that the model is loaded on 'cpu' first and then copied to the device."
+                 " This leads to a slower loading time."
+                 " Please update safetensors to version 0.4.3 or above for improved performance."
+             )
+             model.to(map_location)
+         return model
+
+     # def generate_model_card(self, *args, **kwargs) -> ModelCard:
+     #     card = ModelCard.from_template(
+     #         card_data=self._hub_mixin_info.model_card_data,
+     #         template_str=self._hub_mixin_info.model_card_template,
+     #         repo_url=self._hub_mixin_info.repo_url,
+     #         docs_url=self._hub_mixin_info.docs_url,
+     #         **kwargs,
+     #     )
+     #     return card
+
+     @abc.abstractmethod
+     def get_optim_params(self) -> dict:
+         """Returns the policy-specific parameters dict to be passed on to the optimizer.
+
+         Returns:
+             dict: A dictionary of parameters to optimize.
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def reset(self):
+         """Resets the policy state.
+
+         This method should be called whenever the environment is reset.
+         It handles tasks like clearing caches or resetting internal states for stateful policies.
+         """
+         raise NotImplementedError
+
+     # TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
+     @abc.abstractmethod
+     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
+         """Performs a forward pass of the policy.
+
+         Args:
+             batch: A dictionary of input tensors.
+
+         Returns:
+             tuple[Tensor, dict | None]: A tuple containing:
+                 - The loss tensor.
+                 - An optional dictionary of metrics or auxiliary outputs.
+                   Apart from the loss, items should be logging-friendly native Python types.
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+         """Selects an action based on the input batch.
+
+         This method handles action selection during inference, including
+         caching for stateful policies (e.g. RNNs, Transformers).
+
+         Args:
+             batch: A dictionary of observation tensors.
+
+         Returns:
+             Tensor: The selected action(s).
+         """
+         raise NotImplementedError
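
To make the contract above concrete, here is a minimal sketch of a subclass. `MyPolicy` and its wiring are hypothetical, written only to illustrate the interface (`config_class`/`name` enforced by `__init_subclass__`, plus the four abstract methods); it is not code from the package.

import torch
from torch import Tensor

from opentau.configs.policies import PreTrainedConfig
from opentau.policies.pretrained import PreTrainedPolicy


class MyPolicy(PreTrainedPolicy):
    # Omitting either attribute raises TypeError at class-creation time
    # via __init_subclass__.
    config_class = PreTrainedConfig  # a concrete config subclass in practice
    name = "my_policy"

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.net = torch.nn.Linear(8, 2)  # toy network, dims are arbitrary

    def get_optim_params(self) -> dict:
        return {"params": self.parameters()}

    def reset(self):
        pass  # this toy policy keeps no inference-time state

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
        loss = self.net(batch["observation.state"]).mean()
        return loss, {"loss": loss.item()}

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        return self.net(batch["observation.state"])


# Round trip: save_pretrained() (from HubMixin) writes the config plus
# model.safetensors; from_pretrained() reloads the weights and puts the
# policy in eval mode.
# policy = MyPolicy.from_pretrained("some-user/some-checkpoint")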
opentau/policies/utils.py
@@ -0,0 +1,123 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Utility functions for policy implementations in OpenTau.
+
+ This module provides helper functions for managing data queues, inspecting model
+ properties (device, dtype), determining output shapes, and logging model loading
+ information.
+ """
+
+ import logging
+ from collections import deque
+
+ import torch
+ from torch import nn
+
+
+ def populate_queues(
+     queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
+ ) -> dict[str, deque]:
+     """Populates queues with batch data.
+
+     If a queue is not full (e.g. at the start of an episode), it is filled by repeating
+     the first observation. Otherwise, the latest observation is appended.
+
+     Args:
+         queues: A dictionary of deques to be populated.
+         batch: A dictionary containing the data to add to the queues.
+         exclude_keys: A list of keys to exclude from population. Defaults to None.
+
+     Returns:
+         dict[str, deque]: The updated dictionary of queues.
+     """
+     if exclude_keys is None:
+         exclude_keys = []
+     for key in batch:
+         # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
+         # queues have the keys they want).
+         if key not in queues or key in exclude_keys:
+             continue
+         if len(queues[key]) != queues[key].maxlen:
+             # initialize by copying the first observation several times until the queue is full
+             while len(queues[key]) != queues[key].maxlen:
+                 queues[key].append(batch[key])
+         else:
+             # add latest observation to the queue
+             queues[key].append(batch[key])
+     return queues
+
+
+ def get_device_from_parameters(module: nn.Module) -> torch.device:
+     """Get a module's device by checking one of its parameters.
+
+     Note:
+         Assumes that all parameters have the same device.
+
+     Args:
+         module: The PyTorch module to inspect.
+
+     Returns:
+         torch.device: The device of the module's parameters.
+     """
+     return next(iter(module.parameters())).device
+
+
+ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
+     """Get a module's parameter dtype by checking one of its parameters.
+
+     Note:
+         Assumes that all parameters have the same dtype.
+
+     Args:
+         module: The PyTorch module to inspect.
+
+     Returns:
+         torch.dtype: The data type of the module's parameters.
+     """
+     return next(iter(module.parameters())).dtype
+
+
+ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
+     """Calculates the output shape of a PyTorch module given an input shape.
+
+     Args:
+         module: A PyTorch module.
+         input_shape: A tuple representing the input shape, e.g., (batch_size, channels, height, width).
+
+     Returns:
+         tuple: The output shape of the module.
+     """
+     dummy_input = torch.zeros(size=input_shape)
+     with torch.inference_mode():
+         output = module(dummy_input)
+     return tuple(output.shape)
+
+
+ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None:
+     """Log missing and unexpected keys when loading a model.
+
+     Args:
+         missing_keys: Keys that were expected but not found.
+         unexpected_keys: Keys that were found but not expected.
+     """
+     if missing_keys:
+         # DO NOT UPDATE THIS MESSAGE WITHOUT UPDATING THE REGEX IN .gitlab/scripts/check_pi0_state_keys.py
+         logging.warning(f"Missing key(s) when loading model: {missing_keys}")
+     if unexpected_keys:
+         # DO NOT UPDATE THIS MESSAGE WITHOUT UPDATING THE REGEX IN .gitlab/scripts/check_pi0_state_keys.py
+         logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
opentau/policies/value/__init__.py
@@ -0,0 +1,18 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Value Policy Module.
+
+ This module implements value-based policies for robot control. It includes
+ configurations, models, and reward functions for value estimation.
+ """
opentau/policies/value/configuration_value.py
@@ -0,0 +1,170 @@
+ # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Configuration module for the Value policy.
+
+ This module defines the `ValueConfig` class, which handles the configuration parameters
+ for the Value policy. It includes settings for the model architecture,
+ optimization, scheduling, and data processing.
+ """
+
+ from dataclasses import dataclass, field
+
+ from opentau.configs.policies import PreTrainedConfig
+ from opentau.configs.reward import RewardConfig
+ from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
+ from opentau.optim.optimizers import AdamWConfig
+ from opentau.optim.schedulers import (
+     CosineDecayWithWarmupSchedulerConfig,
+     LRSchedulerConfig,
+ )
+
+
+ @PreTrainedConfig.register_subclass("value")
+ @dataclass
+ class ValueConfig(PreTrainedConfig):
+     """Configuration class for the Value policy.
+
+     Args:
+         n_obs_steps: Number of observation steps to be used.
+         chunk_size: The chunk size for the policy.
+         normalization_mapping: Mapping of feature types to normalization modes.
+         max_state_dim: Maximum dimension for state vectors.
+         resize_imgs_with_padding: Tuple indicating the size to resize images with padding.
+         empty_cameras: Number of empty cameras to add.
+         tokenizer_max_length: Maximum length for the tokenizer.
+         reward_config: Configuration for the reward.
+         optimizer_lr: Learning rate for the optimizer.
+         optimizer_betas: Betas for the optimizer.
+         optimizer_eps: Epsilon for the optimizer.
+         optimizer_weight_decay: Weight decay for the optimizer.
+         scheduler_warmup_steps: Number of warmup steps for the scheduler.
+         scheduler_decay_steps: Number of decay steps for the scheduler.
+         scheduler_decay_lr: Decay learning rate for the scheduler.
+     """
+
+     # Input / output structure.
+     n_obs_steps: int = 1
+     chunk_size: int = 50
+
+     normalization_mapping: dict[str, NormalizationMode] = field(
+         default_factory=lambda: {
+             "VISUAL": NormalizationMode.IDENTITY,
+             "STATE": NormalizationMode.MEAN_STD,
+             "VALUE": NormalizationMode.MEAN_STD,
+         }
+     )
+
+     # Shorter state vectors will be padded
+     max_state_dim: int = 32
+
+     # Image preprocessing
+     resize_imgs_with_padding: tuple[int, int] = (224, 224)
+
+     # Add empty images.
+     empty_cameras: int = 0
+
+     # Tokenizer
+     tokenizer_max_length: int = 48
+
+     # Reward config
+     reward_config: RewardConfig = field(default_factory=RewardConfig)
+
+     # Training presets
+     optimizer_lr: float = 2.5e-5
+     optimizer_betas: tuple[float, float] = (0.9, 0.95)
+     optimizer_eps: float = 1e-8
+     optimizer_weight_decay: float = 1e-10
+
+     scheduler_warmup_steps: int = 1_000
+     scheduler_decay_steps: int = 30_000
+     scheduler_decay_lr: float = 2.5e-6
+
+     def __post_init__(self):
+         """Input validation (not exhaustive)."""
+         super().__post_init__()
+
+         if self.n_obs_steps != 1:
+             raise ValueError(
+                 f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
+             )
+
+         max_episode_length = self.reward_config.reward_normalizer + self.reward_config.C_neg
+         assert max_episode_length < abs(self.reward_config.C_neg), (
+             "Max episode length should be less than the absolute value of C_neg for proper separation of failed and successful episodes"
+         )
+
+     def validate_features(self) -> None:
+         """Validates features and adds empty cameras if specified."""
+         for i in range(self.empty_cameras):
+             key = f"observation.images.empty_camera_{i}"
+             empty_camera = PolicyFeature(
+                 type=FeatureType.VISUAL,
+                 shape=(3, 480, 640),
+             )
+             self.input_features[key] = empty_camera
+
+     def get_optimizer_preset(self) -> AdamWConfig:
+         """Returns the optimizer preset configuration.
+
+         Returns:
+             AdamWConfig: The optimizer configuration.
+         """
+         return AdamWConfig(
+             lr=self.optimizer_lr,
+             betas=self.optimizer_betas,
+             eps=self.optimizer_eps,
+             weight_decay=self.optimizer_weight_decay,
+         )
+
+     def get_scheduler_preset(self) -> LRSchedulerConfig:
+         """Returns the scheduler preset configuration.
+
+         Returns:
+             CosineDecayWithWarmupSchedulerConfig: The scheduler configuration.
+         """
+         return CosineDecayWithWarmupSchedulerConfig(
+             peak_lr=self.optimizer_lr,
+             decay_lr=self.scheduler_decay_lr,
+             num_warmup_steps=self.scheduler_warmup_steps,
+             num_decay_steps=self.scheduler_decay_steps,
+         )
+
+     @property
+     def observation_delta_indices(self) -> None:
+         """Returns the observation delta indices.
+
+         Returns:
+             None: Always returns None.
+         """
+         return None
+
+     @property
+     def action_delta_indices(self) -> list:
+         """Returns the action delta indices.
+
+         Returns:
+             list: List of indices from 0 to chunk_size - 1.
+         """
+         return list(range(self.chunk_size))
+
+     @property
+     def reward_delta_indices(self) -> None:
+         """Returns the reward delta indices.
+
+         Returns:
+             None: Always returns None.
+         """
+         return None
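
A short sketch of how these presets are typically consumed. It assumes the default `RewardConfig` values satisfy the `__post_init__` check and that the inherited `PreTrainedConfig` fields all have defaults; the training loop that would use the returned configs is elided:

from opentau.policies.value.configuration_value import ValueConfig

cfg = ValueConfig()  # defaults: n_obs_steps=1, chunk_size=50, lr=2.5e-5
optim_cfg = cfg.get_optimizer_preset()  # AdamWConfig(lr=2.5e-5, betas=(0.9, 0.95), ...)
sched_cfg = cfg.get_scheduler_preset()  # cosine decay 2.5e-5 -> 2.5e-6, 1k warmup / 30k decay steps
assert cfg.action_delta_indices == list(range(50))  # one index per action-chunk step
assert cfg.observation_delta_indices is None and cfg.reward_delta_indices is None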