opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,393 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Command-line argument parsing and configuration loading utilities.
16
+
17
+ This module provides utilities for parsing command-line arguments, loading
18
+ configuration files from local paths or the HuggingFace Hub, and handling
19
+ plugin discovery and loading. It extends draccus functionality with support
20
+ for path-based configuration loading and plugin system integration.
21
+ """
22
+
23
+ import importlib
24
+ import inspect
25
+ import pkgutil
26
+ import sys
27
+ from argparse import ArgumentError
28
+ from functools import wraps
29
+ from pathlib import Path
30
+ from typing import Sequence
31
+
32
+ import draccus
33
+
34
+ from opentau.utils.utils import has_method
35
+
36
+ PATH_KEY = "path"
37
+ PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
38
+ draccus.set_config_type("json")
39
+
40
+
41
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
    """Collect the CLI arguments nested under ``field_name`` and strip that prefix.

    Only arguments that start with ``--<field_name>.`` are considered. The two
    reserved sub-keys, ``--<field_name>.<draccus.CHOICE_TYPE_KEY>=`` and
    ``--<field_name>.<PATH_KEY>=``, are skipped.

    Args:
        field_name: Name of the field whose nested arguments are wanted.
        args: Command-line arguments to scan. When None, ``sys.argv[1:]`` is used.

    Returns:
        The matching arguments with the ``<field_name>.`` part removed, in their
        original order. The list is empty when nothing matches.

    Example:
        With the arguments
        ``["--arg1=1", "--arg2.subarg1=abc", "--arg2.subarg2=some/path"]``,
        ``get_cli_overrides("arg2", ...)`` returns
        ``["--subarg1=abc", "--subarg2=some/path"]``.
    """
    cli_args = sys.argv[1:] if args is None else args
    nested_prefix = f"--{field_name}."
    # Reserved sub-keys are handled elsewhere and must not be forwarded as overrides.
    reserved_prefixes = (
        f"{nested_prefix}{draccus.CHOICE_TYPE_KEY}=",
        f"{nested_prefix}{PATH_KEY}=",
    )
    return [
        f"--{arg.removeprefix(nested_prefix)}"
        for arg in cli_args
        if arg.startswith(nested_prefix) and not arg.startswith(reserved_prefixes)
    ]
79
+
80
+
81
+ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
82
+ """Parse a single command-line argument value.
83
+
84
+ Args:
85
+ arg_name: Name of the argument to parse (without the '--' prefix).
86
+ args: Sequence of command-line arguments to parse. If None, uses sys.argv[1:].
87
+ Defaults to None.
88
+
89
+ Returns:
90
+ The value of the argument if found, or None if not found.
91
+
92
+ Example:
93
+ For command-line arguments `['--batch_size=32', '--lr=0.001']`:
94
+ - `parse_arg('batch_size')` returns `'32'`
95
+ - `parse_arg('lr')` returns `'0.001'`
96
+ - `parse_arg('missing')` returns `None`
97
+ """
98
+ if args is None:
99
+ args = sys.argv[1:]
100
+ prefix = f"--{arg_name}="
101
+ for arg in args:
102
+ if arg.startswith(prefix):
103
+ return arg[len(prefix) :]
104
+ return None
105
+
106
+
107
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
    """Extract the ``--key=value`` arguments whose key ends with a given suffix.

    The suffix is compared against the key only (the part before the first
    ``=``), so an unrelated argument whose *value* happens to contain the
    suffix is not mistaken for a plugin argument.

    Args:
        plugin_arg_suffix: Suffix identifying plugin-related argument names.
        args: Command-line arguments to scan.

    Returns:
        A dictionary mapping each matching argument name (leading ``--``
        removed when present) to its value. When a name occurs more than once,
        the last value is kept.

    Example:
        >>> args = ['--env.discover_packages_path=my_package',
        ...         '--other_arg=value']
        >>> parse_plugin_args('discover_packages_path', args)
        {'env.discover_packages_path': 'my_package'}
    """
    plugin_args = {}
    for arg in args:
        key, separator, value = arg.partition("=")
        if not separator:
            # Bare flags carry no value and cannot name a plugin package.
            continue
        key = key.removeprefix("--")
        if key.endswith(plugin_arg_suffix):
            plugin_args[key] = value
    return plugin_args
137
+
138
+
139
class PluginLoadError(Exception):
    """Signal that a plugin package, or one of its submodules, could not be imported."""

    pass
141
+
142
+
143
def load_plugin(plugin_path: str) -> None:
    """Import a plugin and, when it is a package, every module directly inside it.

    Plugin registration is expected to happen as a side effect of importing,
    i.e. the imported code registers its gym environments and registers its
    config classes with their parents through the `register_subclass`
    decorator.

    Args:
        plugin_path: Dotted import path of the plugin
            (e.g. "mypackage.plugins.myplugin").

    Raises:
        PluginLoadError: If the plugin, or one of its submodules, cannot be imported.

    Examples:
        >>> load_plugin("external_plugin.core")  # Loads plugin from external package

    Notes:
        - The plugin handles its own registration during import.
        - For a package, the modules found directly in it are imported as well.
        - A plain module (no ``__path__``) has no submodules; importing it is enough.
        - Implementation follows the plugin discovery pattern from the Python
          packaging guidelines.

    See Also:
        https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
    """
    try:
        package_module = importlib.import_module(plugin_path, __package__)
    except ImportError as e:  # ModuleNotFoundError is a subclass of ImportError
        raise PluginLoadError(
            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
        ) from e

    # Only packages expose ``__path__``. For a plain module there is nothing
    # further to discover, and reading the attribute would raise AttributeError.
    search_path = getattr(package_module, "__path__", None)
    if search_path is None:
        return

    try:
        for _finder, module_name, _ispkg in pkgutil.iter_modules(search_path, package_module.__name__ + "."):
            importlib.import_module(module_name)
    except ImportError as e:
        raise PluginLoadError(
            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
        ) from e
185
+
186
+
187
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Look up the ``--<field_name>.<PATH_KEY>=<value>`` argument.

    Args:
        field_name: Field whose path argument is wanted.
        args: Command-line arguments to scan. When None, ``sys.argv[1:]`` is used.

    Returns:
        The path value of the first matching argument, or None when absent.

    Example:
        For ``--policy.path=/path/to/config``, ``get_path_arg('policy')``
        returns ``'/path/to/config'``.
    """
    path_arg_name = f"{field_name}.{PATH_KEY}"
    return parse_arg(path_arg_name, args)
206
+
207
+
208
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
    """Look up the ``--<field_name>.<draccus.CHOICE_TYPE_KEY>=<value>`` argument.

    Args:
        field_name: Field whose type argument is wanted.
        args: Command-line arguments to scan. When None, ``sys.argv[1:]`` is used.

    Returns:
        The type value of the first matching argument, or None when absent.

    Example:
        For ``--policy.type=Pi0Config``, ``get_type_arg('policy')`` returns
        ``'Pi0Config'``.
    """
    type_arg_name = f"{field_name}.{draccus.CHOICE_TYPE_KEY}"
    return parse_arg(type_arg_name, args)
226
+
227
+
228
+ def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
229
+ """Filter out arguments matching a specific field name.
230
+
231
+ Args:
232
+ field_to_filter: The field name to filter out (without the '--' prefix).
233
+ args: Sequence of command-line arguments to filter. If None, uses sys.argv[1:].
234
+ Defaults to None.
235
+
236
+ Returns:
237
+ List of arguments with the specified field filtered out.
238
+
239
+ Example:
240
+ For `['--batch_size=32', '--lr=0.001', '--batch_size=64']`:
241
+ `filter_arg('batch_size')` returns `['--lr=0.001']`.
242
+ """
243
+ if args is None:
244
+ args = sys.argv[1:]
245
+ return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
246
+
247
+
248
def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | None = None) -> list[str]:
    """Remove all CLI arguments of the fields that were given a path argument.

    For every field in ``fields_to_filter`` that has a ``--<field>.<PATH_KEY>=``
    argument, every argument starting with ``--<field>.`` is dropped. Fields
    without a path argument are left untouched.

    Args:
        fields_to_filter: A single field name or a list of field names.
        args: Command-line arguments to filter. When None, ``sys.argv[1:]`` is used.

    Returns:
        A new list with the arguments of the path-configured fields removed.

    Raises:
        ArgumentError: If both a path argument (e.g., `--field_name.path`) and a
            type argument (e.g., `--field_name.type`) are given for the same field.
    """
    if isinstance(fields_to_filter, str):
        fields_to_filter = [fields_to_filter]
    if args is None:
        # Resolve the default once so that detection and filtering operate on
        # the same sequence, and so that the result is never None.
        args = sys.argv[1:]

    filtered_args = list(args)
    for field in fields_to_filter:
        if not get_path_arg(field, args):
            continue
        if get_type_arg(field, args):
            raise ArgumentError(
                argument=None,
                message=f"Cannot specify both --{field}.{PATH_KEY} and --{field}.{draccus.CHOICE_TYPE_KEY}",
            )
        field_prefix = f"--{field}."
        filtered_args = [arg for arg in filtered_args if not arg.startswith(field_prefix)]

    return filtered_args
283
+
284
+
285
+ def filter_distributed_args(args: Sequence[str] | None = None) -> list[str]:
286
+ """Filter out distributed training arguments.
287
+
288
+ This function removes arguments that are automatically injected by distributed
289
+ training frameworks (e.g., DeepSpeed, torchrun) but not recognized by the
290
+ custom argument parser.
291
+
292
+ Args:
293
+ args: The sequence of command-line arguments to be filtered. If None,
294
+ uses sys.argv[1:]. Defaults to None.
295
+
296
+ Returns:
297
+ A filtered list of arguments with distributed training arguments removed.
298
+
299
+ Note:
300
+ Filtered arguments include: local_rank, node_rank, master_addr, master_port,
301
+ world_size, and rank.
302
+ """
303
+ if args is None:
304
+ args = sys.argv[1:]
305
+
306
+ # List of distributed training arguments to filter out
307
+ distributed_args = [
308
+ "--local_rank",
309
+ "--local-rank",
310
+ "--node_rank",
311
+ "--node-rank",
312
+ "--master_addr",
313
+ "--master-addr",
314
+ "--master_port",
315
+ "--master-port",
316
+ "--world_size",
317
+ "--world-size",
318
+ "--rank",
319
+ ]
320
+
321
+ filtered_args = []
322
+ for arg in args:
323
+ should_filter = False
324
+ for distributed_arg in distributed_args:
325
+ if arg.startswith(f"{distributed_arg}=") or arg == distributed_arg:
326
+ should_filter = True
327
+ break
328
+ if not should_filter:
329
+ filtered_args.append(arg)
330
+
331
+ return filtered_args
332
+
333
+
334
def wrap(config_path: Path | None = None):
    """Build a decorator that supplies a parsed config as the first argument.

    It plays the same role as `draccus.wrap`, with three additions:

        1. ``.path`` arguments are removed from the CLI so that they can be
           processed later.
        2. When ``--config_path=...`` is given on the CLI and the config class
           has a ``from_pretrained`` method, the config is created through that
           method.
        3. Plugins named on the CLI are imported before parsing. Plugins are
           expected to register their own config subclasses on import, which
           lets draccus resolve the CLI ``.type`` arguments.

    The config class is taken from the annotation of the decorated function's
    first positional parameter. When the decorated function is called with an
    instance of exactly that class as first argument, the instance is used
    as-is and nothing is parsed.

    Args:
        config_path: Optional configuration file handed to ``draccus.parse``
            when the config is not created through ``from_pretrained``.

    Returns:
        A decorator for functions whose first parameter is the config.

    Note:
        This is a HACK wrapper around draccus.wrap to add custom functionality.
    """

    def wrapper_outer(fn):
        @wraps(fn)
        def wrapper_inner(*args, **kwargs):
            spec = inspect.getfullargspec(fn)
            cfg_type = spec.annotations[spec.args[0]]

            if args and type(args[0]) is cfg_type:
                # A ready-made config was passed in: forward it untouched.
                cfg, args = args[0], args[1:]
                return fn(cfg, *args, **kwargs)

            # Launcher-injected flags are dropped before anything else looks at the CLI.
            remaining = filter_distributed_args(sys.argv[1:])

            # Import the plugins first, then hide their flags from the parser.
            discovered = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, remaining)
            for plugin_flag, plugin_path in discovered.items():
                try:
                    load_plugin(plugin_path)
                except PluginLoadError as e:
                    # add the relevant CLI arg to the error message
                    raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_flag}") from e
                remaining = filter_arg(plugin_flag, remaining)

            config_path_cli = parse_arg("config_path", remaining)

            if has_method(cfg_type, "__get_path_fields__"):
                remaining = filter_path_args(cfg_type.__get_path_fields__(), remaining)

            if has_method(cfg_type, "from_pretrained") and config_path_cli:
                overrides = filter_arg("config_path", remaining)
                cfg = cfg_type.from_pretrained(config_path_cli, cli_args=overrides)
            else:
                cfg = draccus.parse(config_class=cfg_type, config_path=config_path, args=remaining)

            return fn(cfg, *args, **kwargs)

        return wrapper_inner

    return wrapper_outer
@@ -0,0 +1,297 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Policy configuration module.
16
+
17
+ This module provides the base PreTrainedConfig class for policy models, which
18
+ defines the interface and common functionality for all policy configurations.
19
+ It includes support for feature definitions, normalization modes, and loading
20
+ configurations from pretrained models or local paths.
21
+ """
22
+
23
+ import abc
24
+ import os
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Type, TypeVar
28
+
29
+ import draccus
30
+ from huggingface_hub import hf_hub_download
31
+ from huggingface_hub.constants import CONFIG_NAME
32
+ from huggingface_hub.errors import HfHubHTTPError
33
+
34
+ from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
35
+ from opentau.optim.optimizers import OptimizerConfig
36
+ from opentau.optim.schedulers import LRSchedulerConfig
37
+ from opentau.utils.hub import HubMixin
38
+
39
+ # Generic variable that is either PreTrainedConfig or a subclass thereof
40
+ T = TypeVar("T", bound="PreTrainedConfig")
41
+
42
+
43
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
    """Base configuration class for policy models.

    Concrete policies subclass this class and implement the abstract
    properties and methods below.

    Attributes:
        n_obs_steps: Number of environment steps worth of observations to pass
            to the policy (takes the current step and additional steps going back).
        normalization_mapping: Maps a string key to the normalization mode to
            apply (presumably keyed by feature type / modality; confirm against
            the normalization code).
        input_features: Maps an input feature name to its `PolicyFeature`.
        output_features: Maps an output feature name to its `PolicyFeature`.
        device: Device name such as ``cuda``, ``cpu`` or ``mps``; None when unset.
        use_amp: Whether to use Automatic Mixed Precision (AMP) for training and
            evaluation. With AMP, automatic gradient scaling is used.
        pretrained_path: Optional path string; None by default (presumably the
            location the policy was loaded from; confirm against callers).
        cloud_vlm_latency_mean: Mean latency of cloud VLM in seconds.
        cloud_vlm_latency_std: Standard deviation of latency of cloud VLM in seconds.
        cloud_vlm_latency_lower: Lower bound of latency of cloud VLM in seconds.
        cloud_vlm_latency_upper: Upper bound of latency of cloud VLM in seconds.
        action_decoder_latency_mean: Mean latency of action decoder in seconds.
        action_decoder_latency_std: Standard deviation of latency of action
            decoder in seconds.
        action_decoder_latency_lower: Lower bound of latency of action decoder
            in seconds.
        action_decoder_latency_upper: Upper bound of latency of action decoder
            in seconds.
    """

    n_obs_steps: int = 1
    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)

    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
    output_features: dict[str, PolicyFeature] = field(default_factory=dict)

    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool = False
    pretrained_path: str | None = None

    # Mean latency of cloud VLM in seconds.
    cloud_vlm_latency_mean: float = 0.0
    # Standard deviation of latency of cloud VLM in seconds.
    cloud_vlm_latency_std: float = 0.0
    # Lower bound of latency of cloud VLM in seconds.
    cloud_vlm_latency_lower: float = 0.0
    # Upper bound of latency of cloud VLM in seconds.
    cloud_vlm_latency_upper: float = 0.0

    # Mean latency of action decoder in seconds.
    action_decoder_latency_mean: float = 0.0
    # Standard deviation of latency of action decoder in seconds.
    action_decoder_latency_std: float = 0.0
    # Lower bound of latency of action decoder in seconds.
    action_decoder_latency_lower: float = 0.0
    # Upper bound of latency of action decoder in seconds.
    action_decoder_latency_upper: float = 0.0

    def __post_init__(self):
        """Hook run after dataclass initialization.

        The base implementation does nothing; subclasses may override it to
        perform additional initialization.
        """
        pass

    @property
    def type(self) -> str:
        """Get the type name of this configuration.

        Returns:
            The choice name under which this configuration class is registered.
        """
        return self.get_choice_name(self.__class__)

    # NOTE: `abc.abstractproperty` is deprecated; stacking `property` on top of
    # `abc.abstractmethod` is the documented replacement.
    @property
    @abc.abstractmethod
    def observation_delta_indices(self) -> list | None:
        """Get indices for observation delta features.

        Returns:
            List of indices indicating which observation features should be
            treated as deltas, or None if no delta features are used.
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def action_delta_indices(self) -> list | None:
        """Get indices for action delta features.

        Returns:
            List of indices indicating which action features should be treated
            as deltas, or None if no delta features are used.
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def reward_delta_indices(self) -> list | None:
        """Get indices for reward delta features.

        Returns:
            List of indices indicating which reward features should be treated
            as deltas, or None if no delta features are used.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_optimizer_preset(self) -> OptimizerConfig:
        """Get the default optimizer configuration for this policy.

        Returns:
            An OptimizerConfig instance with default settings for this policy type.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
        """Get the default learning rate scheduler configuration for this policy.

        Returns:
            An LRSchedulerConfig instance with default settings for this policy
            type, or None if no scheduler should be used.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def validate_features(self) -> None:
        """Validate that the feature configuration is correct.

        Implementations should check that all required features are present
        and have valid configurations.

        Raises:
            ValueError: If the feature configuration is invalid.
        """
        raise NotImplementedError

    @property
    def robot_state_feature(self) -> PolicyFeature | None:
        """Get the robot state feature from input features.

        Returns:
            The first input feature of type STATE, or None if there is none.
        """
        for ft in self.input_features.values():
            if ft.type is FeatureType.STATE:
                return ft
        return None

    @property
    def env_state_feature(self) -> PolicyFeature | None:
        """Get the environment state feature from input features.

        Returns:
            The first input feature of type ENV, or None if there is none.
        """
        for ft in self.input_features.values():
            if ft.type is FeatureType.ENV:
                return ft
        return None

    @property
    def image_features(self) -> dict[str, PolicyFeature]:
        """Get all visual/image features from input features.

        Returns:
            Dictionary mapping feature names to the input features of type VISUAL.
        """
        return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}

    @property
    def action_feature(self) -> PolicyFeature | None:
        """Get the action feature from output features.

        Returns:
            The first output feature of type ACTION, or None if there is none.
        """
        for ft in self.output_features.values():
            if ft.type is FeatureType.ACTION:
                return ft
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """Write the configuration as JSON to ``save_directory / CONFIG_NAME``.

        Args:
            save_directory: Directory path where the configuration will be saved.
        """
        with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **policy_kwargs,
    ) -> T:
        """Load a policy configuration from a pretrained model or local path.

        Args:
            pretrained_name_or_path: Can be either:

                - A string, the model id of a pretrained config hosted inside a
                  model repo on huggingface.co.
                - A path to a directory containing a configuration file saved
                  using the `_save_pretrained` method.
            force_download: Whether to force (re-)downloading the config files
                and configuration from the HuggingFace Hub. Defaults to False.
            resume_download: Whether to resume downloading the config files.
                Defaults to None.
            proxies: Dictionary of proxies to use for requests. Defaults to None.
            token: The token to use as HTTP bearer authorization. If True, will
                use the token generated when running `huggingface-cli login`.
                Defaults to None.
            cache_dir: Path to a directory in which a downloaded pretrained
                model configuration should be cached. Defaults to None.
            local_files_only: Whether to only look at local files (i.e., do not
                try to download the config). Defaults to False.
            revision: The specific model version to use. It can be a branch
                name, a tag name, or a commit id. Defaults to None.
            **policy_kwargs: Additional keyword arguments. Only 'cli_overrides'
                (a list of command-line style overrides) is read here.

        Returns:
            An instance of the configuration class loaded from the specified path.

        Raises:
            FileNotFoundError: If the configuration file is not found on the
                HuggingFace Hub or in the local directory.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if CONFIG_NAME not in os.listdir(model_id):
                # Parsing without a file would silently yield a default
                # configuration; fail loudly, as documented, instead.
                raise FileNotFoundError(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
            config_file = os.path.join(model_id, CONFIG_NAME)
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
        # something like --policy.path (in addition to --policy.type)
        cli_overrides = policy_kwargs.pop("cli_overrides", [])
        return draccus.parse(cls, config_file, args=cli_overrides)