opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/configs/reward.py
ADDED
@@ -0,0 +1,42 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Reward configuration module.
+
+This module provides the RewardConfig class which contains configuration
+parameters for reward computation in reinforcement learning scenarios.
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class RewardConfig:
+    """Configuration for reward computation settings.
+
+    This configuration is used for reward modeling and computation in reinforcement
+    learning scenarios.
+
+    Args:
+        number_of_bins: Number of bins used for reward discretization or binning.
+            Defaults to 201.
+        C_neg: Negative constant used in reward computation. Defaults to -1000.0.
+        reward_normalizer: Normalization factor for rewards. Defaults to 400.
+        N_steps_look_ahead: Number of steps to look ahead when computing rewards.
+            Defaults to 50.
+    """
+
+    number_of_bins: int = 201
+    C_neg: float = -1000.0
+    reward_normalizer: int = 400
+    N_steps_look_ahead: int = 50
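As a quick illustration of how this config might be consumed, the sketch below instantiates RewardConfig from the module added above and overrides a single field. The import path follows this wheel's layout; the override value is an arbitrary example, not a recommended setting.

# Minimal sketch (assumes opentau 0.1.0 is installed from this wheel).
from opentau.configs.reward import RewardConfig

# Defaults as defined in the dataclass above.
cfg = RewardConfig()
print(cfg.number_of_bins, cfg.C_neg, cfg.reward_normalizer, cfg.N_steps_look_ahead)

# Override one field; the remaining fields keep their defaults.
coarse_cfg = RewardConfig(number_of_bins=51)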
opentau/configs/train.py
ADDED
@@ -0,0 +1,370 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Training pipeline configuration module.
+
+This module provides the TrainPipelineConfig class which contains all configuration
+parameters needed to run a training pipeline, including dataset settings, policy
+configuration, training hyperparameters, and evaluation settings.
+"""
+
+import datetime as dt
+import os
+import sys
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Type
+
+import draccus
+from huggingface_hub import hf_hub_download
+from huggingface_hub.errors import HfHubHTTPError
+
+from opentau.configs import parser
+from opentau.configs.default import DatasetMixtureConfig, EvalConfig, WandBConfig
+from opentau.configs.policies import PreTrainedConfig
+from opentau.envs.configs import EnvConfig
+from opentau.optim import OptimizerConfig
+from opentau.optim.schedulers import LRSchedulerConfig
+from opentau.utils.hub import HubMixin
+
+TRAIN_CONFIG_NAME = "train_config.json"
+
+
+# Somehow, calling `logging.warning()` sets the logger level to WARNING.
+# We print directly to stderr instead.
+def warn(*args, **kwargs):
+    """Print a warning message to stderr.
+
+    This function is used instead of logging.warning() to avoid setting the logger
+    level to WARNING.
+
+    Args:
+        *args: Variable length argument list to print.
+        **kwargs: Arbitrary keyword arguments passed to print().
+    """
+    print("WARNING:", *args, **kwargs, file=sys.stderr)
+
+
+@dataclass
+class TrainPipelineConfig(HubMixin):
+    """Configuration for the training pipeline.
+
+    This class contains all configuration parameters needed to run a training
+    pipeline, including dataset settings, policy configuration, training hyperparameters,
+    and evaluation settings.
+
+    Args:
+        dataset_mixture: Configuration for the dataset mixture to use during training.
+        policy: Configuration for the policy model. If None, must be set via CLI or
+            from a pretrained checkpoint.
+        output_dir: Directory where all run outputs will be saved. If another training
+            session uses the same directory, its contents will be overwritten unless
+            `resume` is set to True.
+        job_name: Name identifier for the training job. If not provided, defaults to
+            the policy type.
+        resume: If True, resume a previous run. Requires `output_dir` to point to
+            an existing run directory with at least one checkpoint. When resuming,
+            the configuration from the checkpoint is used by default, regardless of
+            command-line arguments.
+        seed: Random seed used for training (model initialization, dataset shuffling)
+            and for evaluation environments. Defaults to 1000.
+        resolution: Resolution of images (height, width) in data samples. Defaults to (224, 224).
+        num_cams: Number of cameras for the cloud VLM in each data sample. Defaults to 2.
+        max_state_dim: Maximum dimension of the state vector. Defaults to 32.
+        max_action_dim: Maximum dimension of the action vector. Defaults to 32.
+        action_chunk: Size of action chunk. Defaults to 50.
+        loss_weighting: Dictionary mapping loss type names to their weights.
+            Defaults to {"MSE": 1, "CE": 1}.
+        num_workers: Number of workers for the dataloader. Defaults to 4.
+        batch_size: Total batch size for training. If None, calculated from
+            `dataloader_batch_size * gradient_accumulation_steps`.
+        gradient_accumulation_steps: Number of gradient accumulation steps.
+            Defaults to 1.
+        dataloader_batch_size: Batch size used by the dataloader. If None, calculated
+            from `batch_size // gradient_accumulation_steps`.
+        prefetch_factor: Prefetch factor for the dataloader. If None, uses default.
+        steps: Total number of training steps. Defaults to 100,000.
+        log_freq: Frequency of logging in training iterations. Defaults to 200.
+        save_checkpoint: Whether to save checkpoints during training. Defaults to True.
+        save_freq: Frequency of checkpoint saving in training iterations. Checkpoints
+            are saved every `save_freq` steps and after the last training step.
+            Defaults to 20,000.
+        use_policy_training_preset: If True, use optimizer and scheduler presets from
+            the policy configuration. Defaults to False.
+        optimizer: Configuration for the optimizer. Required if
+            `use_policy_training_preset` is False.
+        scheduler: Configuration for the learning rate scheduler. Required if
+            `use_policy_training_preset` is False.
+        wandb: Configuration for Weights & Biases logging. Defaults to WandBConfig().
+        debug: If True, set logging level to DEBUG. Defaults to False.
+        trace_nans: Enable anomaly detection for debugging NaN/Inf values.
+            Warning: causes large computational overhead. Defaults to False.
+        env: Optional environment configuration for evaluation. Defaults to None.
+        eval: Configuration for evaluation settings. Defaults to EvalConfig().
+        eval_freq: Frequency of evaluation in training steps. If 0, evaluation
+            is disabled. Defaults to 0.
+        last_checkpoint_only: If True, only evaluate the last checkpoint.
+            Defaults to True.
+    """
+
+    dataset_mixture: DatasetMixtureConfig
+    policy: PreTrainedConfig | None = None
+    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
+    # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
+    output_dir: Path | None = None
+    job_name: str | None = None
+    # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+    # `dir` is the directory of an existing run with at least one checkpoint in it.
+    # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+    # regardless of what's provided with the training command at the time of resumption.
+    resume: bool = False
+    # `seed` is used for training (eg: model initialization, dataset shuffling)
+    # AND for the evaluation environments.
+    seed: int | None = 1000
+    # parameters for the Standard Data Format
+    resolution: tuple[int, int] = (224, 224)  # resolution of images (H, W) in data sample
+    num_cams: int = 2  # number of cameras for the cloud VLM in each data sample
+    max_state_dim: int = 32  # maximum dimension of the state vector
+    max_action_dim: int = 32  # maximum dimension of the action vector
+    action_chunk: int = 50  # size of action chunk
+    loss_weighting: dict[str, float] = field(default_factory=lambda: {"MSE": 1, "CE": 1})
+    # Number of workers for the dataloader.
+    num_workers: int = 4
+    batch_size: int | None = None
+    gradient_accumulation_steps: int = 1
+    dataloader_batch_size: int | None = None
+    # Prefetch factor for the dataloader.
+    prefetch_factor: int | None = None
+    steps: int = 100_000
+    log_freq: int = 200
+    save_checkpoint: bool = True
+    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
+    save_freq: int = 20_000
+    use_policy_training_preset: bool = False
+    optimizer: OptimizerConfig | None = None
+    scheduler: LRSchedulerConfig | None = None
+    wandb: WandBConfig = field(default_factory=WandBConfig)
+    # Whether to set the logging level to DEBUG. By default, the logging level will be INFO.
+    debug: bool = False
+    # Enable anomaly detection for debugging NaN/Inf values (warning: large computational overhead)
+    trace_nans: bool = False
+    # optional environment and evaluation config for evaluation
+    env: EnvConfig | None = None
+    eval: EvalConfig | None = field(default_factory=EvalConfig)
+    eval_freq: int = 0  # evaluate every eval_freq steps
+    last_checkpoint_only: bool = True
+
+    def __post_init__(self):
+        """Initialize post-creation attributes and validate batch size configuration."""
+        self.checkpoint_path = None
+
+        if self.dataloader_batch_size is None and self.batch_size is None:
+            raise ValueError("At least one of `batch_size` and `dataloader_batch_size` should be set.")
+        if self.batch_size is None:
+            self.batch_size = self.dataloader_batch_size * self.gradient_accumulation_steps
+        if self.dataloader_batch_size is None:
+            if self.batch_size % self.gradient_accumulation_steps != 0:
+                raise ValueError(
+                    "`batch_size` must be divisible by `gradient_accumulation_steps` "
+                    "when `dataloader_batch_size` is not set. "
+                    f"Got {self.batch_size=}, {self.gradient_accumulation_steps=}."
+                )
+            self.dataloader_batch_size = self.batch_size // self.gradient_accumulation_steps
+        if self.dataloader_batch_size * self.gradient_accumulation_steps != self.batch_size:
+            raise ValueError(
+                "`batch_size` must be equal to `dataloader_batch_size * gradient_accumulation_steps`. "
+                f"Got {self.batch_size=}, {self.dataloader_batch_size=}, {self.gradient_accumulation_steps=}."
+            )
+        assert (
+            self.batch_size >= 1 and self.gradient_accumulation_steps >= 1 and self.dataloader_batch_size >= 1
+        )
+
+        if self.policy:
+            self.policy.max_state_dim = self.max_state_dim
+            self.policy.max_action_state = self.max_action_dim
+            self.policy.chunk_size = self.action_chunk
+        if self.job_name:
+            warn(
+                "cfg.job_name is deprecated and ignored. Set cfg.wandb.project and/or cfg.wandb.name instead."
+            )
+
+    def validate(self):
+        """Validate and finalize the training configuration.
+
+        This method performs several validation and setup tasks:
+        - Loads policy configuration from CLI arguments or pretrained path if specified
+        - Sets up checkpoint paths for resuming training
+        - Validates output directory and creates default if needed
+        - Sets up optimizer and scheduler from presets if enabled
+        - Updates policy configuration with training parameters
+
+        Raises:
+            ValueError: If required configurations are missing or invalid.
+            FileExistsError: If output directory exists and resume is False.
+            NotADirectoryError: If config_path for resuming doesn't exist locally.
+        """
+        # HACK: We parse again the cli args here to get the pretrained paths if there was some.
+        policy_path = parser.get_path_arg("policy")
+        if policy_path:
+            # Only load the policy config
+            cli_overrides = parser.get_cli_overrides("policy")
+            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+            self.policy.pretrained_path = policy_path
+        elif self.resume:
+            # The entire train config is already loaded, we just need to get the checkpoint dir
+            config_path = parser.parse_arg("config_path")
+            if not config_path:
+                raise ValueError(
+                    f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
+                )
+            if not Path(config_path).resolve().exists():
+                raise NotADirectoryError(
+                    f"{config_path=} is expected to be a local path. "
+                    "Resuming from the hub is not supported for now."
+                )
+            policy_path = Path(config_path).parent
+            self.policy.pretrained_path = policy_path
+            self.checkpoint_path = policy_path
+
+        if not self.job_name:
+            self.job_name = f"{self.policy.type}"
+
+        if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
+            raise FileExistsError(
+                f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
+                f"Please change your output directory so that {self.output_dir} is not overwritten."
+            )
+        elif not self.output_dir:
+            now = dt.datetime.now()
+            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
+            self.output_dir = Path("outputs/train") / train_dir
+
+        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
+            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
+        elif self.use_policy_training_preset and not self.resume:
+            self.optimizer = self.policy.get_optimizer_preset()
+            self.scheduler = self.policy.get_scheduler_preset()
+
+        if self.policy:
+            self.policy.max_state_dim = self.max_state_dim
+            self.policy.max_action_state = self.max_action_dim
+            self.policy.chunk_size = self.action_chunk
+
+    @classmethod
+    def __get_path_fields__(cls) -> list[str]:
+        """Get list of field names that support path-based loading.
+
+        This enables the parser to load config from the policy using
+        `--policy.path=local/dir`.
+
+        Returns:
+            List of field names that support path-based configuration loading.
+        """
+        return ["policy"]
+
+    def to_dict(self) -> dict:
+        """Convert the configuration to a dictionary.
+
+        Returns:
+            Dictionary representation of the configuration.
+        """
+        return draccus.encode(self)
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        """Save the configuration to a directory.
+
+        Args:
+            save_directory: Directory path where the configuration will be saved.
+        """
+        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
+            draccus.dump(self, f, indent=4)
+
+    @classmethod
+    def from_pretrained(
+        cls: Type["TrainPipelineConfig"],
+        pretrained_name_or_path: str | Path,
+        *,
+        force_download: bool = False,
+        resume_download: bool = None,
+        proxies: dict | None = None,
+        token: str | bool | None = None,
+        cache_dir: str | Path | None = None,
+        local_files_only: bool = False,
+        revision: str | None = None,
+        **kwargs,
+    ) -> "TrainPipelineConfig":
+        """Load a training configuration from a pretrained model or local path.
+
+        Args:
+            cls: The class to instantiate.
+            pretrained_name_or_path: Can be either:
+
+                - A string, the model id of a pretrained config hosted inside a model
+                  repo on huggingface.co.
+                - A path to a directory containing a configuration file saved using
+                  the `_save_pretrained` method.
+                - A path to a saved configuration JSON file.
+            force_download: Whether to force (re-)downloading the config files and
+                configuration from the HuggingFace Hub. Defaults to False.
+            resume_download: Whether to resume downloading the config files.
+                Defaults to None.
+            proxies: Dictionary of proxies to use for requests. Defaults to None.
+            token: The token to use as HTTP bearer authorization. If True, will use
+                the token generated when running `huggingface-cli login`. Defaults to None.
+            cache_dir: Path to a directory in which a downloaded pretrained model
+                configuration should be cached. Defaults to None.
+            local_files_only: Whether to only look at local files (i.e., do not try
+                to download the config). Defaults to False.
+            revision: The specific model version to use. It can be a branch name, a
+                tag name, or a commit id. Defaults to None.
+            **kwargs: Additional keyword arguments passed to the parser.
+
+        Returns:
+            An instance of TrainPipelineConfig loaded from the specified path.
+
+        Raises:
+            FileNotFoundError: If the configuration file is not found on the
+                HuggingFace Hub or in the local path.
+        """
+        model_id = str(pretrained_name_or_path)
+        config_file: str | None = None
+        if Path(model_id).is_dir():
+            if TRAIN_CONFIG_NAME in os.listdir(model_id):
+                config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
+            else:
+                print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
+        elif Path(model_id).is_file():
+            config_file = model_id
+        else:
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=TRAIN_CONFIG_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except HfHubHTTPError as e:
+                raise FileNotFoundError(
+                    f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
+                ) from e
+
+        cli_args = kwargs.pop("cli_args", [])
+        cfg = draccus.parse(cls, config_file, args=cli_args)
+
+        return cfg
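The `__post_init__` above derives whichever of `batch_size` and `dataloader_batch_size` was left unset and enforces `batch_size == dataloader_batch_size * gradient_accumulation_steps`. The standalone sketch below mirrors that arithmetic outside the dataclass so the three settings can be checked in isolation; the helper name `resolve_batch_sizes` is hypothetical and not part of the package.

# Standalone sketch of the batch-size bookkeeping in TrainPipelineConfig.__post_init__.
# `resolve_batch_sizes` is a hypothetical helper, not an opentau API.
def resolve_batch_sizes(
    batch_size: int | None = None,
    dataloader_batch_size: int | None = None,
    gradient_accumulation_steps: int = 1,
) -> tuple[int, int]:
    if batch_size is None and dataloader_batch_size is None:
        raise ValueError("At least one of `batch_size` and `dataloader_batch_size` should be set.")
    if batch_size is None:
        # Total batch size follows from the per-step dataloader batch size.
        batch_size = dataloader_batch_size * gradient_accumulation_steps
    if dataloader_batch_size is None:
        if batch_size % gradient_accumulation_steps != 0:
            raise ValueError("`batch_size` must be divisible by `gradient_accumulation_steps`.")
        dataloader_batch_size = batch_size // gradient_accumulation_steps
    if dataloader_batch_size * gradient_accumulation_steps != batch_size:
        raise ValueError("`batch_size` must equal `dataloader_batch_size * gradient_accumulation_steps`.")
    return batch_size, dataloader_batch_size

print(resolve_batch_sizes(batch_size=64, gradient_accumulation_steps=4))            # (64, 16)
print(resolve_batch_sizes(dataloader_batch_size=8, gradient_accumulation_steps=4))  # (32, 8)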
opentau/configs/types.py
ADDED
@@ -0,0 +1,76 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Type definitions for configuration classes.
+
+This module provides type definitions used across configuration classes, including
+enumerations for feature types and normalization modes, as well as protocol
+definitions and dataclasses for policy features.
+"""
+
+# Note: We subclass str so that serialization is straightforward
+# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Protocol
+
+
+class FeatureType(str, Enum):
+    """Enumeration of feature types used in policy configurations."""
+
+    STATE = "STATE"
+    """Robot state features."""
+    VISUAL = "VISUAL"
+    """Visual/image features."""
+    ENV = "ENV"
+    """Environment state features."""
+    ACTION = "ACTION"
+    """Action features."""
+
+
+class NormalizationMode(str, Enum):
+    """Enumeration of normalization modes for features."""
+
+    MIN_MAX = "MIN_MAX"
+    """Normalize using min-max scaling."""
+    MEAN_STD = "MEAN_STD"
+    """Normalize using mean and standard deviation."""
+    IDENTITY = "IDENTITY"
+    """No normalization (identity transformation)."""
+
+
+class DictLike(Protocol):
+    """Protocol for dictionary-like objects that support item access.
+
+    This protocol defines the interface for objects that can be accessed
+    using dictionary-style indexing with the `[]` operator.
+    """
+
+    def __getitem__(self, key: Any) -> Any: ...
+
+
+@dataclass
+class PolicyFeature:
+    """Configuration for a policy feature.
+
+    This class describes a single feature used by a policy, including its
+    type and shape information.
+
+    Args:
+        type: The type of feature (STATE, VISUAL, ENV, or ACTION).
+        shape: The shape of the feature as a tuple.
+    """
+
+    type: FeatureType
+    shape: tuple
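To show how these types compose, the snippet below builds PolicyFeature instances and a normalization mapping. The import path matches this wheel's layout, while the shapes and the mapping are made-up examples. Because both enums subclass str, they serialize to JSON without a custom encoder, which is the point of the note at the top of the module.

# Illustrative usage only; shapes and the mapping are arbitrary examples.
import json

from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature

image = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
state = PolicyFeature(type=FeatureType.STATE, shape=(32,))

norm_modes = {
    FeatureType.VISUAL: NormalizationMode.MEAN_STD,
    FeatureType.STATE: NormalizationMode.MIN_MAX,
}

# str-subclassing enums serialize directly (no custom JSON encoder needed).
print(json.dumps({"type": image.type, "shape": image.shape, "norm": norm_modes[image.type]}))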
opentau/constants.py
ADDED
@@ -0,0 +1,52 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Constants used throughout the OpenTau library.
+
+This module defines key constants for:
+- Observation and action keys used in datasets and environments
+- File and directory names for checkpoints, training state, and model storage
+- Cache directory configuration for HuggingFace Hub integration
+
+These constants ensure consistent naming conventions across the codebase and
+provide a centralized location for configuration values.
+"""
+
+import os
+from pathlib import Path
+
+from huggingface_hub.constants import HF_HOME
+
+OBS_STATE = "observation.state"
+OBS_ENV = "observation.environment_state"  # TODO: remove
+OBS_ROBOT = "state"  # TODO: remove
+OBS_IMAGE = "observation.image"  # TODO: remove
+OBS_IMAGES = "observation.images"  # TODO: remove
+ACTION = "actions"  # TODO: remove
+OBS_ENV_STATE = "observation.environment_state"
+
+# files & directories
+CHECKPOINTS_DIR = "checkpoints"
+LAST_CHECKPOINT_LINK = "last"
+PRETRAINED_MODEL_DIR = "pretrained_model"
+TRAINING_STATE_DIR = "training_state"
+RNG_STATE = "rng_state.safetensors"
+TRAINING_STEP = "training_step.json"
+OPTIMIZER_STATE = "optimizer_state.safetensors"
+OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
+SCHEDULER_STATE = "scheduler_state.json"
+
+# cache dir
+default_cache_path = Path(HF_HOME) / "opentau"
+HF_OPENTAU_HOME = Path(os.getenv("HF_OPENTAU_HOME", default_cache_path)).expanduser()
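One practical note on the cache constants above: `HF_OPENTAU_HOME` is resolved once at import time from the `HF_OPENTAU_HOME` environment variable, falling back to `<HF_HOME>/opentau`. A minimal sketch of redirecting it is shown below; the directory path is an arbitrary example.

# Sketch: redirect the opentau cache root before the constants module is first imported.
# The directory path below is an arbitrary example.
import os

os.environ["HF_OPENTAU_HOME"] = "/data/opentau_cache"

from opentau.constants import CHECKPOINTS_DIR, HF_OPENTAU_HOME

print(HF_OPENTAU_HOME)   # PosixPath('/data/opentau_cache')
print(CHECKPOINTS_DIR)   # 'checkpoints'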
opentau/datasets/__init__.py
ADDED
@@ -0,0 +1,84 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dataset management and processing utilities for robot learning and vision-language tasks.
+
+This module provides a comprehensive toolkit for loading, creating, managing, and
+processing datasets for training vision-language-action (VLA) models. It supports
+both robot learning datasets (with actions and states) and vision-language
+grounding datasets (for multimodal understanding tasks).
+
+The module is organized into several key components:
+
+- **Core Datasets**: LeRobotDataset for robot learning data with support for
+  temporal alignment, multi-modal data, and version compatibility.
+- **Grounding Datasets**: Vision-language datasets (CLEVR, COCO-QA, PIXMO, VSR)
+  for training visual understanding without robot actions.
+- **Dataset Mixtures**: WeightedDatasetMixture for combining multiple datasets
+  with controlled sampling proportions.
+- **Data Processing**: Utilities for statistics computation, image/video
+  handling, transforms, and format standardization.
+- **Factory Functions**: High-level functions for creating datasets and mixtures
+  from configuration objects.
+
+Key Features:
+
+- **HuggingFace Integration**: Seamless loading from HuggingFace Hub with
+  automatic version checking and backward compatibility.
+- **Temporal Alignment**: Delta timestamps enable sampling features at
+  different time offsets with optional Gaussian noise for data augmentation.
+- **Multi-modal Support**: Handles images, videos, state vectors, actions,
+  and text prompts with automatic format conversion.
+- **Weighted Sampling**: Combine heterogeneous datasets with configurable
+  sampling weights for balanced training.
+- **Standard Data Format**: Unified data format across all datasets for
+  consistent model input/output interfaces.
+- **Statistics Management**: Automatic computation and aggregation of dataset
+  statistics for normalization.
+- **Video Handling**: Multiple video backends (torchcodec, pyav, video_reader)
+  for efficient frame extraction and encoding.
+- **Asynchronous I/O**: High-performance image writing for real-time data
+  recording without blocking.
+
+Main Modules:
+
+- **lerobot_dataset**: Core dataset implementation for robot learning data.
+- **grounding**: Vision-language grounding datasets (CLEVR, COCO-QA, PIXMO, VSR).
+- **dataset_mixture**: Weighted combination of multiple datasets.
+- **factory**: Factory functions for creating datasets from configurations.
+- **utils**: Utility functions for I/O, metadata management, and validation.
+- **compute_stats**: Statistics computation and aggregation utilities.
+- **transforms**: Image transformation pipelines for data augmentation.
+- **video_utils**: Video encoding, decoding, and metadata extraction.
+- **image_writer**: Asynchronous image writing for high-frequency recording.
+- **sampler**: Episode-aware sampling with boundary frame filtering.
+- **standard_data_format_mapping**: Feature name and loss type mappings.
+
+Example:
+    Create a dataset mixture from configuration:
+
+    >>> from opentau.datasets.factory import make_dataset_mixture
+    >>> mixture = make_dataset_mixture(train_cfg)
+    >>> dataloader = mixture.get_dataloader()
+
+    Load a single dataset:
+
+    >>> from opentau.datasets.factory import make_dataset
+    >>> dataset = make_dataset(dataset_cfg, train_cfg)
+
+    Access grounding datasets:
+
+    >>> from opentau import available_grounding_datasets
+    >>> print(list(available_grounding_datasets.keys()))
+    ['clevr', 'cocoqa', 'dummy', 'pixmo', 'vsr']
+"""