opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/__init__.py ADDED
@@ -0,0 +1,179 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """OpenTau package initialization and registry.
+
+ This module provides lightweight access to available environments, datasets, and policies
+ without importing heavy dependencies. It serves as the main entry point for discovering
+ what components are available in the OpenTau library.
+
+ The module maintains several key registries:
+ - `available_envs`: List of supported environment types (e.g., "aloha", "pusht")
+ - `available_tasks_per_env`: Mapping of environments to their available tasks
+ - `available_datasets_per_env`: Mapping of environments to their compatible datasets
+ - `available_real_world_datasets`: List of real-world robot datasets
+ - `available_grounding_datasets`: Registry for grounding datasets (populated via decorator)
+ - `available_policies`: List of available policy types (e.g., "pi0", "pi05", "value")
+ - `available_policies_per_env`: Mapping of environments to their compatible policies
+
+ Example:
+     ```python
+     import opentau
+     print(opentau.available_envs)
+     print(opentau.available_tasks_per_env)
+     print(opentau.available_datasets)
+     print(opentau.available_datasets_per_env)
+     print(opentau.available_real_world_datasets)
+     print(opentau.available_policies)
+     print(opentau.available_policies_per_env)
+     ```
+
+ When implementing a new dataset, follow these steps:
+ - Update `available_datasets_per_env` in `src/opentau/__init__.py`
+
+ When implementing a new environment (e.g., `gym_aloha`), follow these steps:
+ - Update `available_tasks_per_env` and `available_datasets_per_env` in `src/opentau/__init__.py`
+
+ When implementing a new policy class (e.g., `DiffusionPolicy`), follow these steps:
+ - Update `available_policies` and `available_policies_per_env` in `src/opentau/__init__.py`
+ - Set the required `name` class attribute
+ - Update variables in `tests/test_available.py` by importing your new Policy class
+ """
+
+ import itertools
+
+ from opentau.__version__ import __version__  # noqa: F401
+
+ # TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
+ # refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
+ # a yaml file AND an environment name. The difference should be more obvious.
+ available_tasks_per_env = {}
+ available_envs = list(available_tasks_per_env.keys())
+
+ available_datasets_per_env = {}
+
+ available_real_world_datasets = [
+     "lerobot/aloha_mobile_cabinet",
+     "lerobot/aloha_mobile_chair",
+     "lerobot/aloha_mobile_elevator",
+     "lerobot/aloha_mobile_shrimp",
+     "lerobot/aloha_mobile_wash_pan",
+     "lerobot/aloha_mobile_wipe_wine",
+     "lerobot/aloha_static_battery",
+     "lerobot/aloha_static_candy",
+     "lerobot/aloha_static_coffee",
+     "lerobot/aloha_static_coffee_new",
+     "lerobot/aloha_static_cups_open",
+     "lerobot/aloha_static_fork_pick_up",
+     "lerobot/aloha_static_pingpong_test",
+     "lerobot/aloha_static_pro_pencil",
+     "lerobot/aloha_static_screw_driver",
+     "lerobot/aloha_static_tape",
+     "lerobot/aloha_static_thread_velcro",
+     "lerobot/aloha_static_towel",
+     "lerobot/aloha_static_vinh_cup",
+     "lerobot/aloha_static_vinh_cup_left",
+     "lerobot/aloha_static_ziploc_slide",
+     "lerobot/umi_cup_in_the_wild",
+     "lerobot/unitreeh1_fold_clothes",
+     "lerobot/unitreeh1_rearrange_objects",
+     "lerobot/unitreeh1_two_robot_greeting",
+     "lerobot/unitreeh1_warehouse",
+     "lerobot/nyu_rot_dataset",
+     "lerobot/utokyo_saytap",
+     "lerobot/imperialcollege_sawyer_wrist_cam",
+     "lerobot/utokyo_xarm_bimanual",
+     "lerobot/tokyo_u_lsmo",
+     "lerobot/utokyo_pr2_opening_fridge",
+     "lerobot/cmu_franka_exploration_dataset",
+     "lerobot/cmu_stretch",
+     "lerobot/asu_table_top",
+     "lerobot/utokyo_pr2_tabletop_manipulation",
+     "lerobot/utokyo_xarm_pick_and_place",
+     "lerobot/ucsd_kitchen_dataset",
+     "lerobot/austin_buds_dataset",
+     "lerobot/dlr_sara_grid_clamp",
+     "lerobot/conq_hose_manipulation",
+     "lerobot/columbia_cairlab_pusht_real",
+     "lerobot/dlr_sara_pour",
+     "lerobot/dlr_edan_shared_control",
+     "lerobot/ucsd_pick_and_place_dataset",
+     "lerobot/berkeley_cable_routing",
+     "lerobot/nyu_franka_play_dataset",
+     "lerobot/austin_sirius_dataset",
+     "lerobot/cmu_play_fusion",
+     "lerobot/berkeley_gnm_sac_son",
+     "lerobot/nyu_door_opening_surprising_effectiveness",
+     "lerobot/berkeley_fanuc_manipulation",
+     "lerobot/jaco_play",
+     "lerobot/viola",
+     "lerobot/kaist_nonprehensile",
+     "lerobot/berkeley_mvp",
+     "lerobot/uiuc_d3field",
+     "lerobot/berkeley_gnm_recon",
+     "lerobot/austin_sailor_dataset",
+     "lerobot/utaustin_mutex",
+     "lerobot/roboturk",
+     "lerobot/stanford_hydra_dataset",
+     "lerobot/berkeley_autolab_ur5",
+     "lerobot/stanford_robocook",
+     "lerobot/toto",
+     "lerobot/fmb",
+     "lerobot/droid_100",
+     "lerobot/berkeley_rpt",
+     "lerobot/stanford_kuka_multimodal_dataset",
+     "lerobot/iamlab_cmu_pickup_insert",
+     "lerobot/taco_play",
+     "lerobot/berkeley_gnm_cory_hall",
+     "lerobot/usc_cloth_sim",
+ ]
+
+ available_grounding_datasets = {}
+
+ available_datasets = sorted(
+     set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
+ )
+
+ # lists all available policies from `src/opentau/policies`
+ available_policies = ["pi0", "pi05", "value"]
+
+ # keys and values refer to yaml files
+ available_policies_per_env = {}
+
+ env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
+ env_dataset_pairs = [
+     (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
+ ]
+ env_dataset_policy_triplets = [
+     (env, dataset, policy)
+     for env, datasets in available_datasets_per_env.items()
+     for dataset in datasets
+     for policy in available_policies_per_env[env]
+ ]
+
+
+ def registry_factory(global_dict):
+     def register(name):
+         def decorator(cls):
+             global_dict[name] = cls
+             return cls
+
+         return decorator
+
+     return register
+
+
+ register_grounding_dataset = registry_factory(available_grounding_datasets)
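
The `registry_factory` helper above is a standard decorator-registry pattern. A minimal sketch of how a grounding dataset might register itself; the `MyGroundingDataset` class and the "my_grounding" key are hypothetical and only illustrate the mechanism (real grounding datasets live under `opentau/datasets/grounding/`):

```python
from opentau import available_grounding_datasets, register_grounding_dataset


# Hypothetical class used only to illustrate decorator-based registration.
@register_grounding_dataset("my_grounding")
class MyGroundingDataset:
    """Placeholder grounding dataset."""


# The decorator stored the class in the registry under the chosen name.
print(available_grounding_datasets)  # {'my_grounding': <class '...MyGroundingDataset'>}
```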
opentau/__version__.py ADDED
@@ -0,0 +1,24 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """To enable `opentau.__version__`"""
+
+ from importlib.metadata import PackageNotFoundError, version
+
+ try:
+     __version__ = version("opentau")
+ except PackageNotFoundError:
+     __version__ = "unknown"
opentau/configs/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Configuration module for OpenTau.
+
+ This module provides configuration classes and utilities for training pipelines,
+ datasets, policies, environments, and evaluation settings.
+ """
opentau/configs/default.py ADDED
@@ -0,0 +1,297 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Default configuration classes for datasets, evaluation, and logging.
+
+ This module provides default configuration classes for:
+ - Dataset configuration and dataset mixtures
+ - Weights & Biases (wandb) logging configuration
+ - Evaluation settings and parameters
+ """
+
+ from dataclasses import dataclass, field
+
+ import draccus
+ import numpy as np
+ from draccus.parsers.encoding import encode_dataclass
+
+ from opentau import (
+     policies,  # noqa: F401
+ )
+ from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING, LOSS_TYPE_MAPPING
+ from opentau.datasets.transforms import ImageTransformsConfig
+ from opentau.datasets.video_utils import get_safe_default_codec
+
+ # --- Custom NumPy encoder registration ---
+ # For decoding from cmd/yaml
+ draccus.decode.register(np.ndarray, np.asarray)
+ # For encoding to yaml
+ draccus.encode.register(np.ndarray, lambda x: x.tolist())
+
+
+ @dataclass
+ class DatasetConfig:
+     """Configuration for a dataset.
+
+     You may provide a list of datasets here. `train.py` creates them all and
+     concatenates them. Note: only data keys common between the datasets are kept.
+     Each dataset gets an additional transform that inserts the "dataset_index"
+     into the returned item. The index mapping is made according to the order in
+     which the datasets are provided.
+
+     Args:
+         repo_id: HuggingFace repository ID for the dataset. Exactly one of
+             `repo_id` or `grounding` must be set.
+         grounding: Grounding dataset identifier. Exactly one of `repo_id` or
+             `grounding` must be set.
+         root: Root directory where the dataset will be stored (e.g. 'dataset/path').
+             Defaults to None.
+         episodes: List of episode indices to use from the dataset. If None, all
+             episodes are used. Defaults to None.
+         image_transforms: Configuration for image transformations. Defaults to
+             ImageTransformsConfig().
+         revision: Git revision of the dataset repository to use. Defaults to None.
+         use_imagenet_stats: Whether to use ImageNet statistics for normalization.
+             Defaults to True.
+         video_backend: Video codec backend to use. Defaults to a safe default codec.
+         stats: Dictionary of statistics for normalization, keyed by feature name.
+             Each value is a dictionary with 'mean' and 'std' arrays. Defaults to None.
+         data_features_name_mapping: Optional mapping from dataset feature names to
+             standard feature names. Must be provided together with `loss_type_mapping`.
+             Defaults to None.
+         loss_type_mapping: Optional loss type mapping for the dataset. Must be
+             provided together with `data_features_name_mapping`. Defaults to None.
+
+     Raises:
+         ValueError: If both or neither of `repo_id` and `grounding` are set, or
+             if only one of `data_features_name_mapping` and `loss_type_mapping`
+             is provided.
+     """
+
+     repo_id: str | None = None
+     grounding: str | None = None
+     # Root directory where the dataset will be stored (e.g. 'dataset/path').
+     root: str | None = None
+     episodes: list[int] | None = None
+     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+     revision: str | None = None
+     use_imagenet_stats: bool = True
+     video_backend: str = field(default_factory=get_safe_default_codec)
+     stats: dict[str, dict[str, np.ndarray]] | None = None
+
+     # optional standard data format mapping for the dataset if mapping is not already in standard_data_format_mapping.py
+     data_features_name_mapping: dict[str, str] | None = None
+     loss_type_mapping: str | None = None
+
+     def __post_init__(self):
+         """Validate dataset configuration and register custom mappings if provided."""
+         if (self.repo_id is None) == (self.grounding is None):
+             raise ValueError("Exactly one of `repo_id` or `grounding` for Dataset config should be set.")
+
+         # data_features_name_mapping and loss_type_mapping have to be provided together
+         if (self.data_features_name_mapping is None) != (self.loss_type_mapping is None):
+             raise ValueError(
+                 "`data_features_name_mapping` and `loss_type_mapping` have to be provided together."
+             )
+
+         # add data_features_name_mapping and loss_type_mapping to standard_data_format_mapping.py if they are provided
+         if self.data_features_name_mapping is not None and self.loss_type_mapping is not None:
+             DATA_FEATURES_NAME_MAPPING[self.repo_id] = self.data_features_name_mapping
+             LOSS_TYPE_MAPPING[self.repo_id] = self.loss_type_mapping
+
+
+ @dataclass
+ class DatasetMixtureConfig:
+     """Configuration for a mixture of multiple datasets.
+
+     This configuration allows combining multiple datasets with specified weights
+     for training. The datasets are sampled according to their weights during
+     training, and features are resampled to a common action frequency.
+
+     Args:
+         datasets: List of dataset configs to be used in the mixture.
+         weights: List of weights for each dataset in the mixture. Must be the
+             same length as `datasets`. Defaults to empty list.
+         action_freq: Frequency at which actions from the dataset mixture are
+             resampled, in Hz. Defaults to 30.0.
+         image_resample_strategy: Resample strategy for image features. Must be
+             one of 'linear' or 'nearest'. Defaults to 'nearest'.
+         vector_resample_strategy: Resample strategy for non-image features, such
+             as action or state. Must be one of 'linear' or 'nearest'.
+             Defaults to 'nearest'.
+
+     Raises:
+         ValueError: If the length of `weights` doesn't match `datasets`, if
+             `action_freq` is not positive, or if resample strategies are invalid.
+     """
+
+     # List of dataset configs to be used in the mixture.
+     datasets: list[DatasetConfig] = field(default_factory=list)
+     # List of weights for each dataset in the mixture. Must be the same length as `datasets`.
+     weights: list[float] = field(default_factory=list)
+     # Frequency at which actions from the dataset mixture are resampled, in Hz.
+     action_freq: float = 30.0
+     # Resample strategy for image features
+     image_resample_strategy: str = "nearest"
+     # Resample strategy for non-image features, such as action or state
+     vector_resample_strategy: str = "nearest"
+
+     def __post_init__(self):
+         """Validate dataset mixture configuration."""
+         if len(self.datasets) != len(self.weights):
+             raise ValueError("The length of `weights` must match the length of `datasets`.")
+         if self.action_freq <= 0:
+             raise ValueError(f"`action_freq` must be a positive number, got {self.action_freq}.")
+         if self.image_resample_strategy not in ["linear", "nearest"]:
+             raise ValueError(
+                 f"`image_resample_strategy` must be one of ['linear', 'nearest'], got {self.image_resample_strategy}."
+             )
+         if self.vector_resample_strategy not in ["linear", "nearest"]:
+             raise ValueError(
+                 f"`vector_resample_strategy` must be one of ['linear', 'nearest'], got {self.vector_resample_strategy}."
+             )
+
+
+ @dataclass
+ class WandBConfig:
+     """Configuration for Weights & Biases (wandb) logging.
+
+     Args:
+         enable: Enable Weights & Biases logging. Defaults to False.
+         entity: The entity name in Weights & Biases, e.g. your username or your
+             team name. Defaults to None.
+         project: The project name in Weights & Biases, e.g. "pi0". Defaults to "opentau".
+         run_id: If provided, the run will be forked from this run ID. Defaults to None.
+         name: Name of the run, shown in the UI. Defaults to None.
+         notes: Description of the run, shown in the UI. If None and `enable` is True,
+             will prompt the user for input. Defaults to None.
+         tags: Tags to be added to the run in the UI, e.g. ["robot", "v1.0"].
+             Defaults to empty list.
+         group: Used to group runs in the UI, e.g. "experiment_1", "experiment_2".
+             Defaults to None.
+         job_type: Used to group runs in the UI, e.g. "train", "eval", "test".
+             Defaults to None.
+         mode: Allowed values: 'online', 'offline', 'disabled'. Defaults to None
+             (which uses 'online').
+         allow_resume: If True, resume the run from the last checkpoint when
+             `run_id` is provided. Defaults to True.
+         disable_artifact: Set to True to disable saving an artifact despite
+             `training.save_checkpoint=True`. Defaults to False.
+     """
+
+     enable: bool = False  # Enable Weights & Biases logging.
+     entity: str | None = None  # The entity name in Weights & Biases, e.g. your username or your team name
+     project: str = "opentau"  # The project name in Weights & Biases, e.g. "pi0"
+     run_id: str | None = None  # If provided, the run will be forked from this run ID.
+     name: str | None = None  # Name of the run, shown in the UI
+     notes: str | None = None  # Description of the run, shown in the UI
+     tags: list[str] = field(
+         default_factory=list
+     )  # Tags to be added to the run in the UI, e.g. ["robot", "v1.0"]
+     group: str | None = None  # Used to group runs in the UI, e.g. "experiment_1", "experiment_2"
+     job_type: str | None = None  # Used to group runs in the UI, e.g. "train", "eval", "test"
+     mode: str | None = None  # Allowed values: 'online', 'offline', 'disabled'. Defaults to 'online'
+     allow_resume: bool | None = True  # If True, resume the run from the last checkpoint.
+     # Set to True to disable saving an artifact despite training.save_checkpoint=True
+     disable_artifact: bool = False
+
+     def __post_init__(self):
+         """Prompt user for wandb notes if enabled and notes are not provided."""
+         if not self.enable or self.notes is not None:
+             return
+
+         confirm = False
+         while not confirm:
+             self.notes = input("Please enter a description for wandb logging:\n")
+             confirm = input("Confirm (y/N): ").strip().lower() == "y"
+
+     def to_wandb_kwargs(self, step=None):
+         """Convert configuration to keyword arguments for wandb.init().
+
+         Args:
+             step: Optional training step number. If provided along with `run_id`,
+                 used for resuming or forking runs. Defaults to None.
+
+         Returns:
+             Dictionary of keyword arguments suitable for passing to wandb.init().
+         """
+         kwargs = encode_dataclass(self)
+         excluded_keys = ["enable", "disable_artifact", "project"]
+         for ek in excluded_keys:
+             kwargs.pop(ek)
+
+         allow_resume = kwargs.pop("allow_resume")
+         run_id = kwargs.pop("run_id", None)
+
+         # If both `run_id` and `step` are provided, we handle the resuming or forking logic.
+         if run_id is not None and step is not None:
+             if allow_resume:
+                 # if `allow_resume`, we resume from the `run_id` if provided.
+                 kwargs["id"] = run_id
+                 kwargs["resume"] = "allow"
+             else:
+                 # Without `allow_resume`, we create a new run,
+                 # and add information about the forked run in the notes.
+                 # TODO request `kwargs[fork_from]=f"{run_id}?_step={step}"` feature from wandb
+                 kwargs["notes"] += f"\nForked from run {run_id} at step {step}."
+
+         return kwargs
+
+
+ @dataclass
+ class EvalConfig:
+     """Configuration for evaluation settings.
+
+     Args:
+         n_episodes: Number of episodes to run during evaluation. Defaults to 16.
+         batch_size: Number of environments to use in a gym.vector.VectorEnv.
+             Only used for environments that are not already vectorized.
+             Defaults to 16.
+         use_async_envs: Whether to use asynchronous environments (multiprocessing).
+             Defaults to True.
+         max_episodes_rendered: Maximum number of episodes to render as videos.
+             Defaults to 16.
+         grid_size: Grid dimensions for video summary (rows, cols). If None, will
+             be auto-calculated as a square grid. Defaults to None.
+         recording_root: Root directory for saving evaluation recordings.
+             Defaults to None.
+
+     Raises:
+         ValueError: If `batch_size` is greater than `n_episodes`.
+     """
+
+     n_episodes: int = 16
+     # `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv. (Only used for environments that are not already vectorized.)
+     batch_size: int = 16
+     # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
+     use_async_envs: bool = True
+     max_episodes_rendered: int = 16
+     # Grid dimensions for video summary (rows, cols). If None, will be auto-calculated as square grid.
+     grid_size: tuple[int, int] | None = None
+
+     recording_root: str | None = None
+
+     def __post_init__(self):
+         """Validate evaluation configuration."""
+         if self.batch_size > self.n_episodes:
+             raise ValueError(
+                 "The eval batch size is greater than the number of eval episodes "
+                 f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} "
+                 f"eval environments will be instantiated, but only {self.n_episodes} will be used. "
+                 "This might significantly slow down evaluation. To fix this, you should update your command "
+                 f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
+                 f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
+             )
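
As a rough usage sketch, the config classes above could be composed as follows. The repo ID, grounding key, and weights are illustrative values chosen for this example, not defaults shipped with the package, and the imports assume the package's optional dependencies are installed:

```python
from opentau.configs.default import DatasetConfig, DatasetMixtureConfig, EvalConfig, WandBConfig

# Illustrative mixture: one Hub dataset plus one grounding dataset (names are examples only).
mixture = DatasetMixtureConfig(
    datasets=[
        DatasetConfig(repo_id="lerobot/aloha_static_coffee"),
        DatasetConfig(grounding="cocoqa"),
    ],
    weights=[0.8, 0.2],                 # must have the same length as `datasets`
    action_freq=30.0,                   # Hz, must be positive
    image_resample_strategy="nearest",
    vector_resample_strategy="linear",
)

# batch_size may not exceed n_episodes, otherwise __post_init__ raises ValueError.
eval_cfg = EvalConfig(n_episodes=32, batch_size=16)

# wandb stays disabled here, so __post_init__ does not prompt for run notes.
wandb_cfg = WandBConfig(enable=False)
```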
opentau/configs/libero.py ADDED
@@ -0,0 +1,113 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """LIBERO environment configuration module.
+
+ This module provides configuration classes for LIBERO benchmark evaluation,
+ which is a benchmark suite for learning manipulation tasks. It extends the
+ base training pipeline configuration with LIBERO-specific evaluation parameters.
+ """
+
+ import os
+ from dataclasses import dataclass
+
+ from libero.libero import benchmark, get_libero_path
+
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.utils.monkey_patch import torch_load_patch
+
+ LIBERO_BENCHMARK_DICT = benchmark.get_benchmark_dict()
+
+
+ @dataclass
+ class LiberoEnvConfig:
+     """Configuration for LIBERO environment evaluation.
+
+     LIBERO is a benchmark suite for learning manipulation tasks. This configuration
+     specifies which task suite and task to run, along with evaluation parameters.
+
+     Args:
+         suite: Task suite to run. Must be 'spatial', 'object', 'goal', or '100'.
+         id: Index of the task in the suite to run.
+         max_steps: Maximum number of steps to run for each task. Defaults to 1000.
+         chunk_usage: Number of actions to perform in each chunk before getting a
+             new observation. If None, will be set from the training config's
+             `action_chunk`. Defaults to None.
+         n_simulations: Number of simulations to run for each task. Defaults to 100.
+         video_dir: Directory to save videos of the task execution. Defaults to None.
+
+     Raises:
+         ValueError: If the suite name is invalid or if the task id is out of range
+             for the specified suite.
+     """
+
+     suite: str  # Task suite to run. Must be 'spatial', 'object', 'goal', or '100'.
+     id: int  # index of the task in the suite to run.
+     max_steps: int = 1000  # maximum number of steps to run for each task.
+     chunk_usage: int | None = (
+         None  # number of actions to perform in each chunk before getting a new observation.
+     )
+     n_simulations: int = 100  # number of simulations to run for each task.
+     video_dir: str | None = None  # directory to save videos of the task execution.
+
+     def __post_init__(self):
+         """Validate LIBERO configuration and initialize task-specific attributes."""
+         torch_load_patch()
+         suite = f"libero_{self.suite}".lower()
+         if suite not in LIBERO_BENCHMARK_DICT:
+             raise ValueError(
+                 f"Invalid suite: '{self.suite}'. "
+                 f"Available suites are: {[k.replace('libero_', '') for k in LIBERO_BENCHMARK_DICT]}"
+             )
+         suite = LIBERO_BENCHMARK_DICT[suite]()
+         try:
+             task = suite.get_task(self.id)
+         except IndexError as e:
+             raise ValueError(
+                 f"Invalid task id: {self.id} for suite: {self.suite}. "
+                 f"Available ids must be from 0 to {len(suite.tasks) - 1}."
+             ) from e
+
+         self.bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
+         self.init_states = suite.get_task_init_states(self.id)
+         self.task = task
+
+
+ @dataclass
+ class TrainConfigWithLiberoEval(TrainPipelineConfig):
+     """Training configuration extended with LIBERO evaluation settings.
+
+     This configuration extends the base training pipeline configuration with
+     LIBERO-specific evaluation parameters.
+
+     Args:
+         libero: Configuration for LIBERO environment evaluation. Must be provided.
+             Defaults to None.
+
+     Raises:
+         ValueError: If `libero` is None or if `chunk_usage` is not within valid
+             range (1 to action_chunk).
+     """
+
+     libero: LiberoEnvConfig | None = None
+
+     def __post_init__(self):
+         """Validate LIBERO configuration and set default chunk_usage if needed."""
+         super().__post_init__()
+         if self.libero is None:
+             raise ValueError("Libero config must be provided.")
+         if self.libero.chunk_usage is None:
+             self.libero.chunk_usage = self.action_chunk
+         assert 1 <= self.libero.chunk_usage <= self.action_chunk, (
+             f"Chunk usage must be between 1 and {self.action_chunk=}, got {self.libero.chunk_usage=}."
+         )
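
For reference, a minimal sketch of instantiating `LiberoEnvConfig` on its own. It assumes the `libero` package and its benchmark assets are installed locally; the suite name, task id, and output path are illustrative values only:

```python
from opentau.configs.libero import LiberoEnvConfig

# Illustrative values; __post_init__ validates the suite name and task id
# against the installed LIBERO benchmark and resolves the task metadata.
cfg = LiberoEnvConfig(
    suite="spatial",                    # one of 'spatial', 'object', 'goal', '100'
    id=0,                               # task index within the suite
    max_steps=500,
    n_simulations=10,
    video_dir="outputs/libero_videos",  # hypothetical output path
)

print(cfg.bddl_file)     # path to the task's BDDL file, resolved from the LIBERO install
print(cfg.task)          # task metadata object returned by the benchmark suite
```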