opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Command-line argument parsing and configuration loading utilities.
|
|
16
|
+
|
|
17
|
+
This module provides utilities for parsing command-line arguments, loading
|
|
18
|
+
configuration files from local paths or the HuggingFace Hub, and handling
|
|
19
|
+
plugin discovery and loading. It extends draccus functionality with support
|
|
20
|
+
for path-based configuration loading and plugin system integration.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import importlib
|
|
24
|
+
import inspect
|
|
25
|
+
import pkgutil
|
|
26
|
+
import sys
|
|
27
|
+
from argparse import ArgumentError
|
|
28
|
+
from functools import wraps
|
|
29
|
+
from pathlib import Path
|
|
30
|
+
from typing import Sequence
|
|
31
|
+
|
|
32
|
+
import draccus
|
|
33
|
+
|
|
34
|
+
from opentau.utils.utils import has_method
|
|
35
|
+
|
|
36
|
+
# Key used in CLI args such as `--policy.path=...` to request loading a config
# from a local path or the HuggingFace Hub instead of building it from `.type`.
PATH_KEY = "path"
# Suffix identifying CLI args that name plugin packages to import before parsing
# (e.g. `--env.discover_packages_path=my_pkg`).
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
# All draccus (de)serialization in this package uses JSON.
draccus.set_config_type("json")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
    """Collect CLI arguments nested under ``field_name`` with the prefix stripped.

    Scans the argument list for entries of the form ``--<field_name>.<sub>=<value>``
    and returns them as ``--<sub>=<value>``, skipping the special ``.type`` and
    ``.path`` selectors which are handled separately.

    Args:
        field_name: The field whose nested arguments should be extracted.
        args: Command-line arguments to scan. Falls back to ``sys.argv[1:]``
            when None. Defaults to None.

    Returns:
        The de-nested argument strings (possibly empty).

    Example:
        Supposing the main script was called with:
        ```
        python myscript.py --arg1=1 --arg2.subarg1=abc --arg2.subarg2=some/path
        ```
        then `get_cli_overrides("arg2")` returns:
        ```
        ["--subarg1=abc", "--subarg2=some/path"]
        ```
    """
    if args is None:
        args = sys.argv[1:]
    prefix = f"--{field_name}."
    # `.type` and `.path` selectors are consumed elsewhere; never de-nest them.
    skipped = (f"--{field_name}.{draccus.CHOICE_TYPE_KEY}=", f"--{field_name}.{PATH_KEY}=")
    return [
        f"--{raw.removeprefix(prefix)}"
        for raw in args
        if raw.startswith(prefix) and not raw.startswith(skipped)
    ]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
|
82
|
+
"""Parse a single command-line argument value.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
arg_name: Name of the argument to parse (without the '--' prefix).
|
|
86
|
+
args: Sequence of command-line arguments to parse. If None, uses sys.argv[1:].
|
|
87
|
+
Defaults to None.
|
|
88
|
+
|
|
89
|
+
Returns:
|
|
90
|
+
The value of the argument if found, or None if not found.
|
|
91
|
+
|
|
92
|
+
Example:
|
|
93
|
+
For command-line arguments `['--batch_size=32', '--lr=0.001']`:
|
|
94
|
+
- `parse_arg('batch_size')` returns `'32'`
|
|
95
|
+
- `parse_arg('lr')` returns `'0.001'`
|
|
96
|
+
- `parse_arg('missing')` returns `None`
|
|
97
|
+
"""
|
|
98
|
+
if args is None:
|
|
99
|
+
args = sys.argv[1:]
|
|
100
|
+
prefix = f"--{arg_name}="
|
|
101
|
+
for arg in args:
|
|
102
|
+
if arg.startswith(prefix):
|
|
103
|
+
return arg[len(prefix) :]
|
|
104
|
+
return None
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
    """Parse plugin-related arguments from command-line arguments.

    This function extracts arguments from command-line arguments that match a specified
    suffix pattern. It processes arguments in the format '--key=value' and returns them
    as a dictionary.

    Args:
        plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
        args (Sequence[str]): A sequence of command-line arguments to parse.

    Returns:
        dict: A dictionary containing the parsed plugin arguments where:
            - Keys are the argument names (with '--' prefix removed if present)
            - Values are the corresponding argument values

    Example:
        >>> args = ['--env.discover_packages_path=my_package',
        ...         '--other_arg=value']
        >>> parse_plugin_args('discover_packages_path', args)
        {'env.discover_packages_path': 'my_package'}
    """
    plugin_args = {}
    for arg in args:
        # Only consider `--key=value` style args whose key mentions the suffix.
        if "=" in arg and plugin_arg_suffix in arg:
            key, value = arg.split("=", 1)
            # Remove leading '--' if present
            plugin_args[key.removeprefix("--")] = value
    return plugin_args
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class PluginLoadError(Exception):
    """Signals that importing or initializing a plugin package failed."""
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def load_plugin(plugin_path: str) -> None:
    """Import a plugin package and every module inside it.

    Plugin registration is expected to happen as a side effect of the imports:
    when the package (and its submodules) are imported, the plugin registers
    its gym environments and config subclasses (via `register_subclass`).

    Args:
        plugin_path (str): Python package path of the plugin
            (e.g. "mypackage.plugins.myplugin").

    Raises:
        PluginLoadError: If the package or any of its submodules cannot be imported.

    Examples:
        >>> load_plugin("external_plugin.core")  # Loads plugin from external package

    Notes:
        - The plugin package should handle its own registration during import
        - All submodules in the plugin package will be imported
        - Implementation follows the plugin discovery pattern from Python packaging guidelines

    See Also:
        https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
    """
    try:
        root_pkg = importlib.import_module(plugin_path, __package__)
    except (ImportError, ModuleNotFoundError) as exc:
        raise PluginLoadError(
            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(exc)}"
        ) from exc

    try:
        # Walk the package's immediate modules and import each one so that any
        # registration side effects run.
        for module_info in pkgutil.iter_modules(root_pkg.__path__, root_pkg.__name__ + "."):
            importlib.import_module(module_info.name)
    except ImportError as exc:
        raise PluginLoadError(
            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(exc)}"
        ) from exc
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the ``--<field_name>.path`` value, if one was given.

    Args:
        field_name: The field whose path argument to look up.
        args: Command-line arguments to scan. Falls back to ``sys.argv[1:]``
            when None. Defaults to None.

    Returns:
        The path value if found, otherwise None.

    Example:
        For `--policy.path=/path/to/config`, `get_path_arg('policy')` returns
        `'/path/to/config'`.
    """
    dotted_name = ".".join((field_name, PATH_KEY))
    return parse_arg(dotted_name, args)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Return the ``--<field_name>.type`` value, if one was given.

    Args:
        field_name: The field whose type argument to look up.
        args: Command-line arguments to scan. Falls back to ``sys.argv[1:]``
            when None. Defaults to None.

    Returns:
        The type value if found, otherwise None.

    Example:
        For `--policy.type=Pi0Config`, `get_type_arg('policy')` returns `'Pi0Config'`.
    """
    dotted_name = ".".join((field_name, draccus.CHOICE_TYPE_KEY))
    return parse_arg(dotted_name, args)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
|
|
229
|
+
"""Filter out arguments matching a specific field name.
|
|
230
|
+
|
|
231
|
+
Args:
|
|
232
|
+
field_to_filter: The field name to filter out (without the '--' prefix).
|
|
233
|
+
args: Sequence of command-line arguments to filter. If None, uses sys.argv[1:].
|
|
234
|
+
Defaults to None.
|
|
235
|
+
|
|
236
|
+
Returns:
|
|
237
|
+
List of arguments with the specified field filtered out.
|
|
238
|
+
|
|
239
|
+
Example:
|
|
240
|
+
For `['--batch_size=32', '--lr=0.001', '--batch_size=64']`:
|
|
241
|
+
`filter_arg('batch_size')` returns `['--lr=0.001']`.
|
|
242
|
+
"""
|
|
243
|
+
if args is None:
|
|
244
|
+
args = sys.argv[1:]
|
|
245
|
+
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
    """Filter command-line arguments related to fields with specific path arguments.

    This function removes all arguments related to specified fields when a path
    argument is present for those fields. It also validates that path and type
    arguments are not both specified for the same field.

    Args:
        fields_to_filter: A single field name or a list of field names whose
            arguments need to be filtered.
        args: The sequence of command-line arguments to be filtered. If None,
            uses sys.argv[1:]. Defaults to None.

    Returns:
        A filtered list of arguments, with arguments related to the specified
        fields removed.

    Raises:
        ArgumentError: If both a path argument (e.g., `--field_name.path`) and a
            type argument (e.g., `--field_name.type`) are specified for the same field.
    """
    if isinstance(fields_to_filter, str):
        fields_to_filter = [fields_to_filter]

    # Bug fix: the previous implementation left `args` as None here (returning
    # None / crashing), despite documenting a sys.argv fallback like the other
    # helpers in this module.
    if args is None:
        args = sys.argv[1:]

    filtered_args = list(args)
    for field in fields_to_filter:
        if get_path_arg(field, args):
            if get_type_arg(field, args):
                raise ArgumentError(
                    argument=None,
                    message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
                )
            # A `.path` arg replaces the whole nested config: drop every
            # `--field.*` argument so draccus does not try to parse them.
            filtered_args = [arg for arg in filtered_args if not arg.startswith(f"--{field}.")]

    return filtered_args
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def filter_distributed_args(args: Sequence[str] | None = None) -> list[str]:
|
|
286
|
+
"""Filter out distributed training arguments.
|
|
287
|
+
|
|
288
|
+
This function removes arguments that are automatically injected by distributed
|
|
289
|
+
training frameworks (e.g., DeepSpeed, torchrun) but not recognized by the
|
|
290
|
+
custom argument parser.
|
|
291
|
+
|
|
292
|
+
Args:
|
|
293
|
+
args: The sequence of command-line arguments to be filtered. If None,
|
|
294
|
+
uses sys.argv[1:]. Defaults to None.
|
|
295
|
+
|
|
296
|
+
Returns:
|
|
297
|
+
A filtered list of arguments with distributed training arguments removed.
|
|
298
|
+
|
|
299
|
+
Note:
|
|
300
|
+
Filtered arguments include: local_rank, node_rank, master_addr, master_port,
|
|
301
|
+
world_size, and rank.
|
|
302
|
+
"""
|
|
303
|
+
if args is None:
|
|
304
|
+
args = sys.argv[1:]
|
|
305
|
+
|
|
306
|
+
# List of distributed training arguments to filter out
|
|
307
|
+
distributed_args = [
|
|
308
|
+
"--local_rank",
|
|
309
|
+
"--local-rank",
|
|
310
|
+
"--node_rank",
|
|
311
|
+
"--node-rank",
|
|
312
|
+
"--master_addr",
|
|
313
|
+
"--master-addr",
|
|
314
|
+
"--master_port",
|
|
315
|
+
"--master-port",
|
|
316
|
+
"--world_size",
|
|
317
|
+
"--world-size",
|
|
318
|
+
"--rank",
|
|
319
|
+
]
|
|
320
|
+
|
|
321
|
+
filtered_args = []
|
|
322
|
+
for arg in args:
|
|
323
|
+
should_filter = False
|
|
324
|
+
for distributed_arg in distributed_args:
|
|
325
|
+
if arg.startswith(f"{distributed_arg}=") or arg == distributed_arg:
|
|
326
|
+
should_filter = True
|
|
327
|
+
break
|
|
328
|
+
if not should_filter:
|
|
329
|
+
filtered_args.append(arg)
|
|
330
|
+
|
|
331
|
+
return filtered_args
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def wrap(config_path: Path | None = None):
    """Wrap a function to handle configuration parsing with enhanced features.

    This decorator is similar to `draccus.wrap` but provides three additional features:

    1. Removes '.path' arguments from CLI to process them later
    2. If a 'config_path' is passed and the main config class has a 'from_pretrained'
       method, initializes it from there to allow fetching configs from the hub directly
    3. Loads plugins specified in CLI arguments. These plugins typically register
       their own subclasses of config classes, so that draccus can find the right
       class to instantiate from the CLI '.type' arguments

    Args:
        config_path: Optional path to a configuration file. If provided and the
            config class supports `from_pretrained`, will load from this path.
            Defaults to None.

    Returns:
        A decorator function that wraps the target function with enhanced configuration
        parsing capabilities.

    Note:
        This is a HACK wrapper around draccus.wrap to add custom functionality.
    """

    def wrapper_outer(fn):
        @wraps(fn)
        def wrapper_inner(*args, **kwargs):
            # The config class is taken from the type annotation of the wrapped
            # function's first positional parameter.
            argspec = inspect.getfullargspec(fn)
            argtype = argspec.annotations[argspec.args[0]]
            if len(args) > 0 and type(args[0]) is argtype:
                # The caller already supplied a config instance: use it as-is
                # and skip CLI parsing entirely.
                cfg = args[0]
                args = args[1:]
            else:
                cli_args = sys.argv[1:]
                # Filter out distributed training arguments first
                cli_args = filter_distributed_args(cli_args)
                # Import any plugin packages named on the CLI before parsing, so
                # that their config subclasses are registered with draccus.
                plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
                for plugin_cli_arg, plugin_path in plugin_args.items():
                    try:
                        load_plugin(plugin_path)
                    except PluginLoadError as e:
                        # add the relevant CLI arg to the error message
                        raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
                    # Plugin-discovery args are not part of the config schema;
                    # remove them so draccus does not reject them.
                    cli_args = filter_arg(plugin_cli_arg, cli_args)
                config_path_cli = parse_arg("config_path", cli_args)
                if has_method(argtype, "__get_path_fields__"):
                    # Strip '.path' args here; they are processed later (see
                    # feature 1 in the docstring above).
                    path_fields = argtype.__get_path_fields__()
                    cli_args = filter_path_args(path_fields, cli_args)
                if has_method(argtype, "from_pretrained") and config_path_cli:
                    # Build the config from a pretrained location (local dir or
                    # hub repo), with remaining CLI args applied as overrides.
                    cli_args = filter_arg("config_path", cli_args)
                    cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
                else:
                    # Standard draccus parsing path.
                    cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
            response = fn(cfg, *args, **kwargs)
            return response

        return wrapper_inner

    return wrapper_outer
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Policy configuration module.
|
|
16
|
+
|
|
17
|
+
This module provides the base PreTrainedConfig class for policy models, which
|
|
18
|
+
defines the interface and common functionality for all policy configurations.
|
|
19
|
+
It includes support for feature definitions, normalization modes, and loading
|
|
20
|
+
configurations from pretrained models or local paths.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import abc
|
|
24
|
+
import os
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from typing import Type, TypeVar
|
|
28
|
+
|
|
29
|
+
import draccus
|
|
30
|
+
from huggingface_hub import hf_hub_download
|
|
31
|
+
from huggingface_hub.constants import CONFIG_NAME
|
|
32
|
+
from huggingface_hub.errors import HfHubHTTPError
|
|
33
|
+
|
|
34
|
+
from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
35
|
+
from opentau.optim.optimizers import OptimizerConfig
|
|
36
|
+
from opentau.optim.schedulers import LRSchedulerConfig
|
|
37
|
+
from opentau.utils.hub import HubMixin
|
|
38
|
+
|
|
39
|
+
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
|
40
|
+
T = TypeVar("T", bound="PreTrainedConfig")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
    """
    Base configuration class for policy models.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        input_shapes: A dictionary defining the shapes of the input data for the policy.
        output_shapes: A dictionary defining the shapes of the output data for the policy.
        input_normalization_modes: A dictionary with key representing the modality and the value specifies the
            normalization mode to apply.
        output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
            the original scale.
    """

    n_obs_steps: int = 1
    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)

    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
    output_features: dict[str, PolicyFeature] = field(default_factory=dict)

    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool = False
    pretrained_path: str | None = None

    # Mean latency of cloud VLM in seconds.
    cloud_vlm_latency_mean: float = 0.0
    # Standard deviation of latency of cloud VLM in seconds.
    cloud_vlm_latency_std: float = 0.0
    # Lower bound of latency of cloud VLM in seconds.
    cloud_vlm_latency_lower: float = 0.0
    # Upper bound of latency of cloud VLM in seconds.
    cloud_vlm_latency_upper: float = 0.0

    # Mean latency of action decoder in seconds.
    action_decoder_latency_mean: float = 0.0
    # Standard deviation of latency of action decoder in seconds.
    action_decoder_latency_std: float = 0.0
    # Lower bound of latency of action decoder in seconds.
    action_decoder_latency_lower: float = 0.0
    # Upper bound of latency of action decoder in seconds.
    action_decoder_latency_upper: float = 0.0

    def __post_init__(self):
        """Initialize post-creation attributes.

        This method can be overridden by subclasses to perform additional
        initialization after the dataclass is created.
        """
        pass

    @property
    def type(self) -> str:
        """Get the type name of this configuration.

        Returns:
            The choice name of this configuration class.
        """
        return self.get_choice_name(self.__class__)

    # NOTE: `abc.abstractproperty` is deprecated since Python 3.3; the stacked
    # `@property` + `@abc.abstractmethod` form below is the supported equivalent.
    @property
    @abc.abstractmethod
    def observation_delta_indices(self) -> list | None:
        """Get indices for observation delta features.

        Returns:
            List of indices indicating which observation features should be
            treated as deltas, or None if no delta features are used.
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def action_delta_indices(self) -> list | None:
        """Get indices for action delta features.

        Returns:
            List of indices indicating which action features should be treated
            as deltas, or None if no delta features are used.
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def reward_delta_indices(self) -> list | None:
        """Get indices for reward delta features.

        Returns:
            List of indices indicating which reward features should be treated
            as deltas, or None if no delta features are used.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_optimizer_preset(self) -> OptimizerConfig:
        """Get the default optimizer configuration for this policy.

        Returns:
            An OptimizerConfig instance with default settings for this policy type.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
        """Get the default learning rate scheduler configuration for this policy.

        Returns:
            An LRSchedulerConfig instance with default settings for this policy type,
            or None if no scheduler should be used.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def validate_features(self) -> None:
        """Validate that the feature configuration is correct.

        This method should check that all required features are present and
        have valid configurations.

        Raises:
            ValueError: If the feature configuration is invalid.
        """
        raise NotImplementedError

    @property
    def robot_state_feature(self) -> PolicyFeature | None:
        """Get the robot state feature from input features.

        Returns:
            The PolicyFeature with type STATE if found, or None otherwise.
        """
        for ft in self.input_features.values():
            if ft.type is FeatureType.STATE:
                return ft
        return None

    @property
    def env_state_feature(self) -> PolicyFeature | None:
        """Get the environment state feature from input features.

        Returns:
            The PolicyFeature with type ENV if found, or None otherwise.
        """
        for ft in self.input_features.values():
            if ft.type is FeatureType.ENV:
                return ft
        return None

    @property
    def image_features(self) -> dict[str, PolicyFeature]:
        """Get all visual/image features from input features.

        Returns:
            Dictionary mapping feature names to PolicyFeature instances with
            type VISUAL.
        """
        return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}

    @property
    def action_feature(self) -> PolicyFeature | None:
        """Get the action feature from output features.

        Returns:
            The PolicyFeature with type ACTION if found, or None otherwise.
        """
        for ft in self.output_features.values():
            if ft.type is FeatureType.ACTION:
                return ft
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save the configuration to a directory.

        Args:
            save_directory: Directory path where the configuration will be saved.
        """
        with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **policy_kwargs,
    ) -> T:
        """Load a policy configuration from a pretrained model or local path.

        Args:
            cls: The class to instantiate.
            pretrained_name_or_path: Can be either:

                - A string, the model id of a pretrained config hosted inside a model
                  repo on huggingface.co.
                - A path to a directory containing a configuration file saved using
                  the `_save_pretrained` method.
            force_download: Whether to force (re-)downloading the config files and
                configuration from the HuggingFace Hub. Defaults to False.
            resume_download: Whether to resume downloading the config files.
                Defaults to None.
            proxies: Dictionary of proxies to use for requests. Defaults to None.
            token: The token to use as HTTP bearer authorization. If True, will use
                the token generated when running `huggingface-cli login`. Defaults to None.
            cache_dir: Path to a directory in which a downloaded pretrained model
                configuration should be cached. Defaults to None.
            local_files_only: Whether to only look at local files (i.e., do not try
                to download the config). Defaults to False.
            revision: The specific model version to use. It can be a branch name, a
                tag name, or a commit id. Defaults to None.
            **policy_kwargs: Additional keyword arguments. May include 'cli_overrides'
                for command-line argument overrides.

        Returns:
            An instance of the configuration class loaded from the specified path.

        Raises:
            FileNotFoundError: If the configuration file is not found on the
                HuggingFace Hub or in the local path.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, CONFIG_NAME)
            else:
                print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
        # something like --policy.path (in addition to --policy.type)
        cli_overrides = policy_kwargs.pop("cli_overrides", [])
        return draccus.parse(cls, config_file, args=cli_overrides)
|