opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
3
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
"""Normalization and Unnormalization utilities for policies.
|
|
18
|
+
|
|
19
|
+
This module provides classes and functions to normalize and unnormalize data
|
|
20
|
+
(e.g., observations and actions) based on statistical properties (mean, std, min, max).
|
|
21
|
+
It handles different normalization modes and supports creating buffers for statistics.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import sys
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
import torch
|
|
28
|
+
from torch import Tensor, nn
|
|
29
|
+
|
|
30
|
+
from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
31
|
+
|
|
32
|
+
EPS = 1e-8 # Small epsilon value for numerical stability in normalization
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def warn_missing_keys(features: dict[str, PolicyFeature], batch: dict[str, Tensor], mode: str) -> None:
    """Emit a stderr warning for every expected feature absent from the batch.

    Args:
        features: Mapping of feature names the policy expects to find in the batch.
        batch: The data batch being processed.
        mode: Label for the current operation (e.g., "normalization" or "unnormalization")
            used in the warning message.
    """
    missing = set(features) - set(batch)
    ansi_red, ansi_reset = "\033[91m", "\033[0m"
    for missing_key in missing:
        message = f"{ansi_red}Warning: {missing_key} was missing from the batch during {mode}.{ansi_reset}"
        print(message, file=sys.stderr)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def create_stats_buffers(
    features: dict[str, PolicyFeature],
    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
    """
    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
    statistics.

    Args:
        features: Dictionary mapping feature names to PolicyFeature objects.
        norm_map: Dictionary mapping feature types to NormalizationMode.
        stats: Optional dictionary containing pre-computed statistics (mean/std for MEAN_STD
            features, min/max for MIN_MAX features). If None, buffers are initialized with infinity.

    Returns:
        dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
        `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.

    Raises:
        ValueError: If stats contain types other than np.ndarray or torch.Tensor.
    """
    stats_buffers = {}

    for key, ft in features.items():
        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
        if norm_mode is NormalizationMode.IDENTITY:
            continue

        assert isinstance(norm_mode, NormalizationMode)

        shape = tuple(ft.shape)

        if ft.type is FeatureType.VISUAL:
            # sanity checks
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # override image shape to be invariant to height and width
            shape = (c, 1, 1)

        # The stat entries this normalization mode actually uses.
        if norm_mode is NormalizationMode.MEAN_STD:
            stat_names = ("mean", "std")
        elif norm_mode is NormalizationMode.MIN_MAX:
            stat_names = ("min", "max")
        else:
            stat_names = ()

        # Note: we initialize the stats to infinity. They should be overwritten downstream
        # by `stats` or `policy.load_state_dict`, as expected. During forward, we assert
        # they are not infinity anymore.
        buffer = nn.ParameterDict(
            {
                name: nn.Parameter(torch.full(shape, torch.inf, dtype=torch.float32), requires_grad=False)
                for name in stat_names
            }
        )

        # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
        if stats and stat_names:
            # Bug fix: probe a stat this mode actually uses for the type check. The previous
            # code always probed `stats[key]["mean"]`, which raised KeyError for MIN_MAX
            # features whose stats only provide "min"/"max".
            probe = stats[key][stat_names[0]]
            if isinstance(probe, np.ndarray):
                for name in stat_names:
                    buffer[name].data = torch.from_numpy(stats[key][name]).to(dtype=torch.float32)
            elif isinstance(probe, torch.Tensor):
                # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see
                # duplicated tensors anywhere (for example, when we use the same stats for normalization
                # and unnormalization). See the logic here
                # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
                for name in stat_names:
                    buffer[name].data = stats[key][name].clone().to(dtype=torch.float32)
            else:
                type_ = type(probe)
                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")

        stats_buffers[key] = buffer
    return stats_buffers
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _no_stats_error_str(name: str) -> str:
|
|
146
|
+
"""Returns an error message string for missing statistics.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
name: Name of the statistic (e.g., "mean", "std").
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
str: The error message string.
|
|
153
|
+
"""
|
|
154
|
+
return (
|
|
155
|
+
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
|
156
|
+
"pretrained model."
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class Normalize(nn.Module):
    """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """Build the normalizer and register one stats buffer per non-identity feature.

        Args:
            features: Mapping from input modality (e.g. "observation.image") to its
                PolicyFeature definition.
            norm_map: Mapping from feature type to the NormalizationMode to apply.
            stats: Optional mapping from modality to statistic tensors (e.g.
                `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). When given (first-time
                training), they overwrite the default buffers immediately. When omitted
                (finetuning/eval), the buffers are expected to be filled later through
                `policy.load_state_dict(state_dict)`, so the dataset is not needed for stats.
        """
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats
        # Buffers are registered as attributes named "buffer_<key with dots replaced>",
        # so they round-trip through the module's state_dict.
        for key, buffer in create_stats_buffers(features, norm_map, stats).items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a copy of `batch` with every configured feature normalized.

        Args:
            batch: Dictionary containing the data to normalize.

        Returns:
            dict[str, Tensor]: The normalized batch data.

        Raises:
            ValueError: If an unknown normalization mode is encountered.
        """
        warn_missing_keys(self.features, batch, "normalization")
        out = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
            if key not in out:
                # FIXME(aliberts, rcadene): This might lead to silent fail!
                continue

            mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            # Empty tensors are skipped too: they would not broadcast against the stats.
            if mode is NormalizationMode.IDENTITY or out[key].numel() == 0:
                continue

            buf = getattr(self, "buffer_" + key.replace(".", "_"))

            if mode is NormalizationMode.MEAN_STD:
                mean, std = buf["mean"], buf["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                out[key] = (out[key] - mean) / (std + EPS)
            elif mode is NormalizationMode.MIN_MAX:
                low, high = buf["min"], buf["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # Scale to [0, 1], then remap to [-1, 1].
                out[key] = ((out[key] - low) / (high - low + EPS)) * 2 - 1
            else:
                raise ValueError(mode)
        return out
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class Unnormalize(nn.Module):
    """
    Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
    original range used by the environment.
    """

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """Build the unnormalizer and register one stats buffer per non-identity feature.

        Args:
            features: Mapping from output modality (e.g. "action") to its PolicyFeature definition.
            norm_map: Mapping from feature type to the NormalizationMode to invert.
            stats: Optional mapping from modality to statistic tensors. When given, they
                overwrite the default buffers; otherwise the buffers are expected to be
                filled later via `policy.load_state_dict(state_dict)`.
        """
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats
        # e.g. `self.buffer_observation_state["mean"]` holds a `torch.tensor(state_dim)`.
        for key, buffer in create_stats_buffers(features, norm_map, stats).items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a copy of `batch` with every configured feature mapped back to its original range.

        Args:
            batch: Dictionary containing the data to unnormalize.

        Returns:
            dict[str, Tensor]: The unnormalized batch data.

        Raises:
            ValueError: If an unknown normalization mode is encountered.
        """
        warn_missing_keys(self.features, batch, "unnormalization")
        out = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
            if key not in out:
                continue

            mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            # Empty tensors are skipped too: they would not broadcast against the stats.
            if mode is NormalizationMode.IDENTITY or out[key].numel() == 0:
                continue

            buf = getattr(self, "buffer_" + key.replace(".", "_"))

            if mode is NormalizationMode.MEAN_STD:
                mean, std = buf["mean"], buf["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                out[key] = out[key] * (std + EPS) + mean
            elif mode is NormalizationMode.MIN_MAX:
                low, high = buf["min"], buf["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # Map from [-1, 1] back to [0, 1], then to the original [min, max] range.
                out[key] = ((out[key] + 1) / 2) * (high - low + EPS) + low
            else:
                raise ValueError(mode)
        return out
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""PI0 Policy Module.
|
|
15
|
+
|
|
16
|
+
This module implements the π0 (Pi0) Vision-Language-Action Flow Model policy,
|
|
17
|
+
designed for general robot control. It includes the policy definition,
|
|
18
|
+
configuration, and model architecture.
|
|
19
|
+
"""
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
Configuration module for the PI0 Policy.
|
|
18
|
+
|
|
19
|
+
This module defines the `PI0Config` class, which handles the configuration parameters
|
|
20
|
+
for the PI0 Vision-Language-Action Flow Model. It includes settings for the model architecture,
|
|
21
|
+
optimization, scheduling, and data processing.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
|
+
from typing import Literal
|
|
27
|
+
|
|
28
|
+
from opentau.configs.policies import PreTrainedConfig
|
|
29
|
+
from opentau.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
30
|
+
from opentau.optim.optimizers import AdamWConfig
|
|
31
|
+
from opentau.optim.schedulers import (
|
|
32
|
+
CosineDecayWithWarmupSchedulerConfig,
|
|
33
|
+
LRSchedulerConfig,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
    """Configuration class for the PI0 Policy.

    This class defines the configuration parameters for the PI0 model, including
    input/output structure, model architecture, training settings, and preprocessing.

    Args:
        n_obs_steps: Number of observation steps to use. Defaults to 1.
        chunk_size: Size of the action chunk. The upper bound for n_action_steps. Defaults to 50.
        n_action_steps: Number of action steps to predict. Defaults to 50.
        safety_buffer: Safety buffer size. Defaults to 0.
        normalization_mapping: Mapping of feature names to normalization modes.
            Defaults to identity for visual features and mean-std for state and action.
        max_state_dim: Maximum dimension for state vectors. Shorter vectors are padded. Defaults to 32.
        max_action_dim: Maximum dimension for action vectors. Shorter vectors are padded. Defaults to 32.
        resize_imgs_with_padding: Target size (height, width) for image resizing with padding.
            Defaults to (224, 224).
        empty_cameras: Number of empty camera inputs to add. Used for specific adaptations like
            Aloha simulation. Defaults to 0.
        tokenizer_max_length: Maximum length for tokenizer. Defaults to 48.
        proj_width: Width of the projection layer. Defaults to 1024.
        dropout: Dropout rate. Defaults to 0.1.
        num_steps: Number of flow matching steps for decoding. Defaults to 10.
        advantage_threshold: Advantage binarization threshold for AWR. Defaults to 0.0.
        advantage: Advantage conditioning mode. One of "ignore", "on", "use".
            "use" uses values from dataset, "ignore" disables conditioning,
            "on" sets advantage to True (for expert demos). Defaults to "use".
        init_strategy: Initialization strategy. One of "no_init", "full_he_init", "expert_only_he_init".
            Defaults to "full_he_init".
        use_cache: Whether to use KV cache during inference. Defaults to True.
        attention_implementation: Attention implementation to use ("eager" or "fa2"). Defaults to "eager".
        freeze_vision_encoder: Whether to freeze the vision encoder during fine-tuning. Defaults to True.
        train_expert_only: Whether to train only the expert module. Defaults to False.
        train_state_proj: Whether to train the state projection layer. Defaults to True.
        optimizer_lr: Learning rate for the optimizer. Defaults to 2.5e-5.
        optimizer_betas: Beta parameters for AdamW optimizer. Defaults to (0.9, 0.95).
        optimizer_eps: Epsilon parameter for AdamW optimizer. Defaults to 1e-8.
        optimizer_weight_decay: Weight decay for AdamW optimizer. Defaults to 1e-10.
        scheduler_warmup_steps: Number of warmup steps for the scheduler. Defaults to 1_000.
        scheduler_decay_steps: Number of decay steps for the scheduler. Defaults to 30_000.
        scheduler_decay_lr: Target learning rate after decay. Defaults to 2.5e-6.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50
    safety_buffer: int = 0

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (224, 224)

    # Add empty images. Used by pi0_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Tokenizer
    tokenizer_max_length: int = 48

    # Projector
    proj_width: int = 1024

    # Dropout
    dropout: float = 0.1

    # Decoding
    num_steps: int = 10

    # advantage binarization threshold for AWR
    advantage_threshold: float = 0.0

    # When set to "use", the advantage values provided in the dataset will be used.
    # When set to "ignore", no advantage conditioning will be applied.
    # When set to "on", the advantage will always be True.
    # This should only be "on" when training on expert demonstrations or interventions.
    advantage: Literal["ignore", "on", "use"] = "use"

    # Initialization strategy
    init_strategy: Literal["no_init", "full_he_init", "expert_only_he_init"] = "full_he_init"

    # Attention utils
    use_cache: bool = True
    attention_implementation: str = "eager"  # or fa2

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = False
    train_state_proj: bool = True

    # Training presets
    optimizer_lr: float = 2.5e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    # TODO: Add EMA

    def __post_init__(self):
        """Post-initialization validation (not exhaustive).

        Raises:
            ValueError: If the configuration is internally inconsistent (e.g.
                `n_action_steps > chunk_size`, multiple observation steps, an
                unknown `init_strategy`, or an incompatible pretrained checkpoint).
        """
        super().__post_init__()

        # TODO(Steven): Validate device and amp? in all policy configs?
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            # Fixed typo in the message: was `nobs_steps`.
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

        # Raise instead of `assert` so validation still runs under `python -O`.
        if self.init_strategy not in ("no_init", "full_he_init", "expert_only_he_init"):
            raise ValueError(
                f"Invalid init strategy: {self.init_strategy} must be one of ['no_init', 'full_he_init', 'expert_only_he_init']"
            )

        if self.init_strategy == "expert_only_he_init" and self.pretrained_path == "lerobot/pi0":
            raise ValueError(
                "You cannot load pretrained PI0 model when init_strategy is 'expert_only_he_init' due to differences in PaliGemma tokenizer vocab sizes."
            )

        if self.pretrained_path is not None and self.pretrained_path != "lerobot/pi0":
            logging.info("Setting init_strategy to 'no_init' because we are resuming from a checkpoint.")
            self.init_strategy = "no_init"

    def validate_features(self) -> None:
        """Validates the features and adds empty cameras if configured.

        This method checks feature configurations and dynamically adds empty camera inputs
        to `self.input_features` based on the `empty_cameras` parameter.
        """
        # TODO: implement value error
        # if not self.image_features and not self.env_state_feature:
        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")

        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        """Returns the default optimizer configuration.

        Returns:
            AdamWConfig: The optimizer configuration with default parameters.
        """
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> LRSchedulerConfig:
        """Returns the default scheduler configuration.

        Returns:
            CosineDecayWithWarmupSchedulerConfig: The scheduler configuration with default parameters.
        """
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        """Indices for observation deltas.

        Returns:
            None: As observation deltas are not used.
        """
        return None

    @property
    def action_delta_indices(self) -> list[int]:
        """Indices for action deltas.

        Returns:
            list[int]: A list of indices corresponding to the chunk size.
        """
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """Indices for reward deltas.

        Returns:
            None: As reward deltas are not used.
        """
        return None
|