opentau 0.1.0__py3-none-any.whl
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/policies/pi0/modeling_pi0.py
@@ -0,0 +1,994 @@
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""π0: A Vision-Language-Action Flow Model for General Robot Control

[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
"""

import math
from collections import deque

import torch
import torch.nn.functional as F  # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer

from opentau.policies.normalize import Normalize, Unnormalize
from opentau.policies.pi0.configuration_pi0 import PI0Config
from opentau.policies.pi0.paligemma_with_expert import (
    PaliGemmaWithExpertConfig,
    PaliGemmaWithExpertModel,
)
from opentau.policies.pretrained import PreTrainedPolicy
from opentau.policies.utils import log_model_loading_keys
from opentau.utils.utils import get_safe_dtype


def create_sinusoidal_pos_embedding(
    time: Tensor, dimension: int, min_period: float, max_period: float, device: torch.device | str = "cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: A 1-D tensor of shape (batch_size,).
        dimension: The dimension of the embedding vectors. Must be divisible by 2.
        min_period: The minimum period of the sinusoidal functions.
        max_period: The maximum period of the sinusoidal functions.
        device: The device to create the tensors on. Defaults to "cpu".

    Returns:
        A tensor of shape (batch_size, dimension) containing the positional embeddings.

    Raises:
        ValueError: If dimension is not divisible by 2 or if time tensor is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size,)`.")

    dtype = (
        get_safe_dtype(torch.float64, device.type)
        if isinstance(device, torch.device)
        else get_safe_dtype(torch.float64, device)
    )
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
    return pos_emb
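
A standalone usage sketch (editorial, not part of the packaged file). The dimension and period values mirror the call in `embed_suffix` below; `proj_width` is assumed to be 1024 here:

import torch

t = torch.tensor([0.0, 0.5, 1.0])  # flow-matching timesteps in [0, 1]
emb = create_sinusoidal_pos_embedding(t, dimension=1024, min_period=4e-3, max_period=4.0)
print(emb.shape)  # torch.Size([3, 1024])
# Layout is [sin | cos]: at t = 0 every sin channel is 0 and every cos channel is 1.
print(emb[0, 0].item(), emb[0, -1].item())  # 0.0 1.0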


def make_att_2d_masks(pad_masks: Tensor, att_masks: Tensor) -> Tensor:
    """Creates a 2-D attention mask given padding and 1-D attention masks.

    Tokens can attend to valid input tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
        pad_masks: bool[B, N], true if it is part of the input, false if padding.
        att_masks: int32[B, N] mask that is 1 where previous tokens cannot depend on
            it and 0 where it shares the same attention mask as the previous token.

    Returns:
        A 2D attention mask tensor of shape (B, N, N).

    Raises:
        ValueError: If att_masks or pad_masks are not 2D.
    """
    if att_masks.ndim != 2:
        raise ValueError(f"att_masks must be 2D, got {att_masks.ndim}D")
    if pad_masks.ndim != 2:
        raise ValueError(f"pad_masks must be 2D, got {pad_masks.ndim}D")

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    att_2d_masks = att_2d_masks & pad_2d_masks
    return att_2d_masks
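
A minimal check (standalone sketch, editorial) that the cumulative-sum trick reproduces the prefix-lm pattern from the docstring:

import torch

pad = torch.ones(1, 6, dtype=torch.bool)  # no padding
att = torch.tensor([[0, 0, 0, 1, 1, 1]])  # prefix-lm example above
print(make_att_2d_masks(pad, att)[0].int())
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1]], dtype=torch.int32)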


def resize_with_pad(img: Tensor, width: int, height: int, pad_value: int = -1) -> Tensor:
    """Resizes an image to fit within the specified dimensions while maintaining aspect ratio,
    and pads the remaining area with the specified value.

    Args:
        img: Input image tensor of shape (batch_size, channels, current_height, current_width).
        width: Target width.
        height: Target height.
        pad_value: Value to use for padding. Defaults to -1.

    Returns:
        The resized and padded image tensor of shape (batch_size, channels, height, width).

    Raises:
        ValueError: If the input image tensor does not have 4 dimensions.
    """
    # Assumed to be a no-op when the width and height already fit.
    if img.ndim != 4:
        raise ValueError(f"(b, c, h, w) expected, but got {img.shape}")

    cur_height, cur_width = img.shape[2:]

    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_img = F.interpolate(
        img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
    )

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))

    # Pad on the left and top of the image
    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
    return padded_img
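
A standalone sketch (editorial) of the letterboxing behaviour: a 256x512 image squeezed into 224x224 keeps its 2:1 aspect ratio, so the content shrinks to 112x224 and 112 rows of padding land at the top:

import torch

img = torch.rand(1, 3, 256, 512)  # (b, c, h, w)
out = resize_with_pad(img, width=224, height=224, pad_value=0)
print(out.shape)  # torch.Size([1, 3, 224, 224])
print((out[0, :, :112, :] == 0).all())  # tensor(True): the top 112 rows are padding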


def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
    """Pads the last dimension of a vector to a new size with zeros.

    Args:
        vector: Input tensor. Can be (batch_size x sequence_length x features_dimension)
            or (batch_size x features_dimension).
        new_dim: The new size for the last dimension.

    Returns:
        The padded tensor.
    """
    if vector.shape[-1] == new_dim:
        return vector
    shape = list(vector.shape)
    current_dim = shape[-1]
    shape[-1] = new_dim
    new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
    new_vector[..., :current_dim] = vector
    return new_vector


class PI0Policy(PreTrainedPolicy):
    """Wrapper class around PI0FlowMatching model to train and run inference within OpenTau."""

    config_class = PI0Config
    name = "pi0"

    def __init__(
        self,
        config: PI0Config,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """Initializes the PI0Policy.

        Args:
            config: Policy configuration class instance.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is
                expected that they will be passed with a call to `load_state_dict` before the policy
                is used.
        """

        super().__init__(config)
        config.validate_features()
        self.config = config
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FlowMatching(config)

        self.reset()

    def reset(self) -> None:
        """This should be called whenever the environment is reset."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    @classmethod
    def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
        """
        Transform state dict keys to match the expected model structure.

        Transformations:
        - model.paligemma_with_expert.paligemma.language_model.lm_head ->
          model.paligemma_with_expert.paligemma.lm_head
        - model.paligemma_with_expert.paligemma.language_model.model ->
          model.paligemma_with_expert.paligemma.model.language_model
        - model.paligemma_with_expert.paligemma.vision_tower ->
          model.paligemma_with_expert.paligemma.model.vision_tower
        - model.paligemma_with_expert.paligemma.multi_modal_projector ->
          model.paligemma_with_expert.paligemma.model.multi_modal_projector

        Also handles tied weights between lm_head.weight and
        embed_tokens.weight.

        Args:
            state_dict: The state dictionary to transform.

        Returns:
            The transformed state dictionary.
        """
        import re

        transformed_dict = {}

        transformations = [
            (
                re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
                ".paligemma_with_expert.paligemma.lm_head",
            ),
            (
                re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
                ".paligemma_with_expert.paligemma.model.language_model",
            ),
            (
                re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
                ".paligemma_with_expert.paligemma.model.vision_tower",
            ),
            (
                re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
                ".paligemma_with_expert.paligemma.model.multi_modal_projector",
            ),
        ]

        for key, value in state_dict.items():
            new_key = key
            for pattern, replacement in transformations:
                new_key = pattern.sub(replacement, new_key)
            transformed_dict[new_key] = value

        # Handle tied weights: lm_head.weight and embed_tokens.weight share memory
        lm_head_key = None
        embed_tokens_key = None

        for key in transformed_dict:
            if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
                lm_head_key = key
            elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
                embed_tokens_key = key
            if lm_head_key and embed_tokens_key:
                break

        if lm_head_key and not embed_tokens_key:
            embed_tokens_key = lm_head_key.replace(
                ".lm_head.weight", ".model.language_model.embed_tokens.weight"
            )
            transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
        elif embed_tokens_key and not lm_head_key:
            lm_head_key = embed_tokens_key.replace(
                ".model.language_model.embed_tokens.weight", ".lm_head.weight"
            )
            transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]

        return transformed_dict
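
A standalone sketch (editorial) of the remapping applied to a single checkpoint key:

import torch

old_key = (
    "model.paligemma_with_expert.paligemma.language_model.model"
    ".layers.0.self_attn.q_proj.weight"
)
new_dict = PI0Policy._transform_state_dict_keys({old_key: torch.zeros(1)})
print(next(iter(new_dict)))
# model.paligemma_with_expert.paligemma.model.language_model.layers.0.self_attn.q_proj.weight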

    @classmethod
    def _load_as_safetensor(
        cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
    ) -> "PI0Policy":
        """Override to apply key transformations before loading.

        Args:
            model: The model instance.
            model_file: Path to the model file.
            map_location: Device mapping location.
            strict: Whether to strictly enforce state dict matching.

        Returns:
            The loaded model instance.
        """
        from safetensors.torch import load_file

        # Load the state dict from file safely
        state_dict = load_file(model_file, device=map_location)

        # Apply key transformations
        transformed_state_dict = cls._transform_state_dict_keys(state_dict)

        # Apply tiling of linear input weights if needed
        model._tile_linear_input_weight(transformed_state_dict)

        # Load the transformed state dict
        msg = model.load_state_dict(transformed_state_dict, strict=strict)

        # Log missing and unexpected keys
        log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
        return model

    def get_optim_params(self) -> dict:
        """Returns the parameters to be optimized.

        Returns:
            A generator over the model parameters.
        """
        return self.parameters()

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Override the from_pretrained method to display an important disclaimer.

        Args:
            *args: Positional arguments passed to super().from_pretrained.
            **kwargs: Keyword arguments passed to super().from_pretrained.

        Returns:
            The loaded model instance.
        """
        print(
            "⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
            "   It is not expected to perform as well as the original implementation. \n"
            "   Original implementation: https://github.com/Physical-Intelligence/openpi"
        )
        return super().from_pretrained(*args, **kwargs)

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk of actions given environment observations.

        Args:
            batch: Batch of data containing environment observations.

        Returns:
            The predicted action chunk.

        Raises:
            NotImplementedError: Always, as this method is not implemented for PI0.
        """
        raise NotImplementedError("Currently not implemented for PI0")

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `sample_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `sample_actions` when
        at most `config.safety_buffer` actions remain in the queue.

        Args:
            batch: Batch of data containing environment observations.
            noise: Optional noise tensor to be used during sampling.

        Returns:
            The selected action tensor.
        """
        self.eval()

        # Action queue logic for n_action_steps > 1. When the action queue is (nearly) depleted,
        # populate it by querying the policy.
        if len(self._action_queue) <= self.config.safety_buffer:
            actions = self.sample_actions(batch, noise=noise)
            self._action_queue.extend(actions)
        return self._action_queue.popleft()
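
A control-loop sketch (editorial; `env` and its batch layout are hypothetical stand-ins, not an API from this package) showing how the queue amortizes one chunk prediction over several environment steps:

policy.reset()  # empty the action queue
obs = env.reset()  # assumed to yield a batch dict with "state", images and "prompt"
for _ in range(1000):
    action = policy.select_action(obs)  # re-plans only when the queue runs low
    obs, done = env.step(action)
    if done:
        obs = env.reset()
        policy.reset()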

    @torch.no_grad()
    def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Sample actions from the policy given environment observations.

        Args:
            batch: Batch of data containing environment observations.
            noise: Optional noise tensor.

        Returns:
            The sampled actions tensor of shape (n_action_steps, batch_size, action_dim).
        """
        batch = self.normalize_inputs(batch)

        images, img_masks = self.prepare_images(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)

        state = batch["state"]
        actions = self.model.sample_actions(
            images,
            img_masks,
            lang_tokens,
            lang_masks,
            state,
            noise=noise,
        )

        # Unpad actions
        original_action_dim = self.config.action_feature.shape[0]
        actions = actions[:, :, :original_action_dim]

        actions = self.unnormalize_outputs({"actions": actions})["actions"]

        # `self.model.sample_actions` returns a (batch_size, n_action_steps, action_dim) tensor, but the
        # queue effectively has shape (n_action_steps, batch_size, *), hence the transpose.
        actions = actions.transpose(0, 1)
        return actions

    def forward(
        self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None
    ) -> dict[str, Tensor]:
        """Do a full training forward pass to compute the loss.

        Args:
            batch: Batch of data containing environment observations, actions, and targets.
            noise: Optional noise tensor.
            time: Optional time tensor.

        Returns:
            A dictionary containing the loss components ("MSE" and "CE").
        """
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)

        images, img_masks = self.prepare_images(batch)
        state = batch["state"]
        lang_tokens, lang_masks = self.prepare_language(batch)
        actions = batch["actions"]
        actions_is_pad = batch.get("action_is_pad")

        losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)

        if actions_is_pad is not None:
            in_episode_bound = ~actions_is_pad
            losses = losses * in_episode_bound.unsqueeze(-1)

        # Remove padding
        losses = losses[:, :, : self.config.max_action_dim]

        # For backward pass
        loss = losses.mean()

        return {"MSE": loss, "CE": torch.zeros_like(loss, requires_grad=True)}

    def prepare_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
        """Apply Pi0 preprocessing to the images.

        Resizes to 224x224 with padding to keep the aspect ratio, and converts the pixel range
        from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.

        Args:
            batch: Batch of data containing image tensors.

        Returns:
            A tuple containing:
                - images: A list of processed image tensors.
                - img_masks: A list of image mask tensors.

        Raises:
            ValueError: If no image features are present in the batch.
        """
        images = []
        img_masks = []

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [key for key in self.config.image_features if key not in batch]

        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. "
                f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
            )

        # Preprocess image features present in the batch
        for key in present_img_keys:
            img = batch[key]

            if self.config.resize_imgs_with_padding is not None:
                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)

            # Normalize from range [0,1] to [-1,1] as expected by SigLIP
            img = img * 2.0 - 1.0

            bsize = img.shape[0]
            device = img.device
            mask = torch.ones(bsize, dtype=torch.bool, device=device)
            images.append(img)
            img_masks.append(mask)

        # Create image features not present in the batch as fully padded images
        # (value -1, i.e. black in SigLIP's [-1, 1] range) with a zeroed mask.
        for num_empty_cameras in range(len(missing_img_keys)):
            if num_empty_cameras >= self.config.empty_cameras:
                break
            img = torch.ones_like(img) * -1
            mask = torch.zeros_like(mask)
            images.append(img)
            img_masks.append(mask)

        return images, img_masks

    def prepare_language(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Tokenize the text input.

        Args:
            batch: Batch of data containing "prompt" and potentially "advantage".

        Returns:
            A tuple containing:
                - lang_tokens: Tensor of language tokens.
                - lang_masks: Tensor of language attention masks.
        """
        device = batch["state"].device
        tasks = batch["prompt"]

        # PaliGemma prompt has to end with a new line
        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

        for idx, task in enumerate(tasks):
            if self.config.advantage == "on":  # always add positive advantage
                tasks[idx] = f"{task}Advantage: positive\n"
            elif self.config.advantage == "use":  # add advantage based on threshold
                adv = batch["advantage"][idx] >= self.config.advantage_threshold
                adv = "positive" if adv else "negative"
                tasks[idx] = f"{task}Advantage: {adv}\n"

        tokenized_prompt = self.language_tokenizer(
            tasks,
            padding="max_length",
            padding_side="right",
            max_length=self.config.tokenizer_max_length,
            return_tensors="pt",
        )
        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks
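
The resulting prompt strings, sketched standalone (editorial; the task text and threshold are made up):

task = "pick up the red block"
# config.advantage == "on": every prompt is conditioned on a positive advantage.
print(repr(f"{task}\nAdvantage: positive\n"))
# config.advantage == "use": the label depends on the per-sample advantage estimate.
advantage, threshold = 0.12, 0.5
label = "positive" if advantage >= threshold else "negative"
print(repr(f"{task}\nAdvantage: {label}\n"))  # 'pick up the red block\nAdvantage: negative\n'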


class PI0FlowMatching(nn.Module):
    """
    π0: A Vision-Language-Action Flow Model for General Robot Control

    [Paper](https://www.physicalintelligence.company/download/pi0.pdf)

    ┌──────────────────────────────┐
    │               actions        │
    │               ▲              │
    │              ┌┴─────┐        │
    │  kv cache    │Gemma │        │
    │  ┌──────────►│Expert│       │
    │  │           │      │        │
    │ ┌┴────────┐  │x 10  │        │
    │ │         │  └▲──▲──┘        │
    │ │PaliGemma│   │  │           │
    │ │         │   │  robot state │
    │ │         │   noise          │
    │ └▲──▲─────┘                  │
    │  │  │                        │
    │  │  image(s)                 │
    │  language tokens             │
    └──────────────────────────────┘
    """

    def __init__(self, config: PI0Config):
        """Initializes the PI0FlowMatching model.

        Args:
            config: Model configuration.
        """
        super().__init__()
        self.config = config

        # Only load the pretrained PaliGemma weights if we are He-initializing the expert only.
        load_pretrained_paligemma = self.config.init_strategy == "expert_only_he_init"
        paligemma_with_expert_config = PaliGemmaWithExpertConfig(
            freeze_vision_encoder=self.config.freeze_vision_encoder,
            train_expert_only=self.config.train_expert_only,
            attention_implementation=self.config.attention_implementation,
            load_pretrained_paligemma=load_pretrained_paligemma,
            dropout=self.config.dropout,
        )
        self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)

        # Projections are float32
        self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
        self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
        self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)

        self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
        self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)

        self.set_requires_grad()

        self._init_model()

    def set_requires_grad(self) -> None:
        """Sets the requires_grad attribute for state projection parameters."""
        for params in self.state_proj.parameters():
            params.requires_grad = self.config.train_state_proj

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights using He (Kaiming) initialization.

        Args:
            module: The module to initialize.
        """
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _init_model(self) -> None:
        """Initialize the model weights based on the configuration."""
        if self.config.init_strategy == "no_init":
            return
        elif self.config.init_strategy == "full_he_init":
            for m in self.modules():
                self._init_weights(m)
        elif self.config.init_strategy == "expert_only_he_init":
            for m in self.paligemma_with_expert.gemma_expert.modules():
                self._init_weights(m)
        else:
            raise ValueError(f"Invalid init strategy: {self.config.init_strategy}")

    def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor:
        """Samples Gaussian noise.

        Args:
            shape: The shape of the noise tensor.
            device: The device to create the tensor on.

        Returns:
            A tensor containing the sampled noise.
        """
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )
        return noise

    def sample_time(self, bsize: int, device: torch.device | str) -> Tensor:
        """Samples time steps from a Beta distribution.

        Args:
            bsize: Batch size.
            device: The device to create the tensor on.

        Returns:
            A tensor containing the sampled time steps.
        """
        beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
        time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
        time = time_beta * 0.999 + 0.001
        return time
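
A standalone sketch (editorial) of this time distribution: Beta(1.5, 1.0) has density proportional to sqrt(t), so samples are biased toward t near 1 (high noise), and the affine map keeps them inside [0.001, 1.0]:

import torch

beta = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
t = beta.sample((100_000,)) * 0.999 + 0.001
print(t.min() >= 0.001, t.max() <= 1.0)  # tensor(True) tensor(True)
print(t.mean())  # ~0.6, since Beta(1.5, 1.0) has mean 1.5 / 2.5 = 0.6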

    def embed_prefix(
        self,
        images: list[Tensor],
        img_masks: list[Tensor],
        lang_tokens: Tensor,
        lang_masks: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Embed images with SigLIP and language tokens with the embedding layer to prepare
        for PaliGemma transformer processing.

        Args:
            images: List of image tensors.
            img_masks: List of image mask tensors.
            lang_tokens: Language token tensor.
            lang_masks: Language mask tensor.

        Returns:
            A tuple containing:
                - embs: Concatenated embeddings tensor.
                - pad_masks: Concatenated padding masks tensor.
                - att_masks: Attention masks tensor.
        """
        # TODO: avoid list in python and torch.cat; prefer pre-allocation with torch.empty
        embs = []
        pad_masks = []
        att_masks = []

        # TODO: remove for loop
        for (
            img,
            img_mask,
        ) in zip(images, img_masks, strict=False):
            img_emb = self.paligemma_with_expert.embed_image(img)
            img_emb = img_emb.to(dtype=torch.bfloat16)

            # Image embeddings don't need to be unnormalized because the `fix/lerobot_openpi` branch
            # of huggingface already removed the normalization inside PaliGemma.

            bsize, num_img_embs = img_emb.shape[:2]
            img_mask = img_mask[:, None].expand(bsize, num_img_embs)

            embs.append(img_emb)
            pad_masks.append(img_mask)

            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)

        # Normalize language embeddings
        lang_emb_dim = lang_emb.shape[-1]
        lang_emb = lang_emb * math.sqrt(lang_emb_dim)

        embs.append(lang_emb)
        pad_masks.append(lang_masks)

        # Full attention between image and language inputs
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def embed_suffix(
        self, state: Tensor, noisy_actions: Tensor, timestep: Tensor
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.

        Args:
            state: State tensor.
            noisy_actions: Tensor containing noisy actions.
            timestep: Tensor containing timesteps.

        Returns:
            A tuple containing:
                - embs: Concatenated embeddings tensor.
                - pad_masks: Concatenated padding masks tensor.
                - att_masks: Attention masks tensor.
        """
        embs = []
        pad_masks = []
        att_masks = []

        # Embed state
        state_emb = self.state_proj(state)
        state_emb = state_emb.to(dtype=torch.bfloat16)
        embs.append(state_emb[:, None, :])
        bsize = state_emb.shape[0]
        dtype = state_emb.dtype
        device = state_emb.device

        state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
        pad_masks.append(state_mask)

        # Set attention masks so that image and language inputs do not attend to state or actions
        att_masks += [1]

        # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = create_sinusoidal_pos_embedding(
            timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
        )
        time_emb = time_emb.type(dtype=dtype)

        # Fuse timestep + action information using an MLP
        noisy_actions = noisy_actions.to(dtype=dtype)
        action_emb = self.action_in_proj(noisy_actions)

        time_emb = time_emb[:, None, :].expand_as(action_emb)
        action_time_emb = torch.cat([action_emb, time_emb], dim=2)

        action_time_emb = self.action_time_mlp_in(action_time_emb)
        action_time_emb = F.silu(action_time_emb)  # swish == silu
        action_time_emb = self.action_time_mlp_out(action_time_emb)

        # Add to input tokens
        embs.append(action_time_emb)

        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
        pad_masks.append(action_time_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens
        att_masks += [1] + ([0] * (self.config.n_action_steps - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks
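
The suffix token layout, sketched standalone (editorial; n_action_steps = 50 is an assumed value): one state token followed by the action block. The two leading 1s open new attention blocks, so the prefix cannot attend to the state token, and prefix + state cannot attend to the action tokens:

n_action_steps = 50
att = [1] + ([1] + [0] * (n_action_steps - 1))
print(len(att), att[:4])  # 51 [1, 1, 0, 0]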

    def forward(
        self,
        images: list[Tensor],
        img_masks: list[Tensor],
        lang_tokens: Tensor,
        lang_masks: Tensor,
        state: Tensor,
        actions: Tensor,
        noise: Tensor | None = None,
        time: Tensor | None = None,
    ) -> Tensor:
        """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors).

        Args:
            images: List of image tensors.
            img_masks: List of image mask tensors.
            lang_tokens: Language token tensor.
            lang_masks: Language mask tensor.
            state: State tensor.
            actions: Action tensor.
            noise: Optional noise tensor.
            time: Optional time tensor.

        Returns:
            The computed loss tensor.
        """
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)

        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        (_, suffix_out), _ = self.paligemma_with_expert.forward(
            attention_mask=att_2d_masks,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, suffix_embs],
            use_cache=False,
            fill_kv_cache=False,
        )
        suffix_out = suffix_out[:, -self.config.n_action_steps :]
        # As in the original openpi code, upcast the attention output
        v_t = self.action_out_proj(suffix_out)
        v_t = v_t.to(dtype=torch.float32)

        losses = F.mse_loss(u_t, v_t, reduction="none")
        return losses
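
A sanity check on the flow-matching construction (standalone sketch, editorial): with x_t = t * noise + (1 - t) * actions, the target u_t = noise - actions equals dx_t/dt, so integrating a perfect velocity prediction from t = 1 (pure noise) down to t = 0 recovers the actions. Ten steps are used here, assuming config.num_steps = 10:

import torch

actions = torch.randn(2, 50, 32)
noise = torch.randn_like(actions)

x_t = noise.clone()  # t = 1
dt = -1.0 / 10
for _ in range(10):
    v_t = noise - actions  # oracle in place of the network's prediction
    x_t += dt * v_t
print(torch.allclose(x_t, actions, atol=1e-6))  # True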

    def sample_actions(
        self,
        images: list[Tensor],
        img_masks: list[Tensor],
        lang_tokens: Tensor,
        lang_masks: Tensor,
        state: Tensor,
        noise: Tensor | None = None,
    ) -> Tensor:
        """Do a full inference forward and compute the action (batch_size x num_steps x num_motors).

        Args:
            images: List of image tensors.
            img_masks: List of image mask tensors.
            lang_tokens: Language token tensor.
            lang_masks: Language mask tensor.
            state: State tensor.
            noise: Optional noise tensor.

        Returns:
            The sampled action tensor.
        """
        bsize = state.shape[0]
        device = state.device

        if noise is None:
            actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
            noise = self.sample_noise(actions_shape, device)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        # Compute image and language key value cache
        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=self.config.use_cache,
            fill_kv_cache=True,
        )

        dt = -1.0 / self.config.num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                state,
                prefix_pad_masks,
                past_key_values,
                x_t,
                expanded_time,
            )

            # Euler step
            x_t += dt * v_t
            time += dt
        return x_t
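
A standalone sketch (editorial) of the integration schedule: the half-step tolerance in the while condition makes the loop run exactly num_steps times despite floating-point drift, re-using the prefix KV cache computed above on every iteration:

num_steps = 10  # assumed config value
dt = -1.0 / num_steps
t, calls = 1.0, 0
while t >= -dt / 2:  # visits t = 1.0, 0.9, ..., 0.1
    calls += 1
    t += dt
print(calls)  # 10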

    def denoise_step(
        self,
        state: Tensor,
        prefix_pad_masks: Tensor,
        past_key_values: list[dict[str, Tensor]],
        x_t: Tensor,
        timestep: Tensor,
    ) -> Tensor:
        """Apply one denoising step to the noise `x_t` at a given timestep.

        Args:
            state: State tensor.
            prefix_pad_masks: Prefix padding masks.
            past_key_values: Past key values from the VLM.
            x_t: Current noise tensor.
            timestep: Current timestep.

        Returns:
            The predicted velocity tensor (v_t).
        """
        suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)

        suffix_len = suffix_pad_masks.shape[1]
        batch_size = prefix_pad_masks.shape[0]
        prefix_len = prefix_pad_masks.shape[1]
        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)

        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)

        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=self.config.use_cache,
            fill_kv_cache=False,
        )
        suffix_out = outputs_embeds[1]
        suffix_out = suffix_out[:, -self.config.n_action_steps :]
        v_t = self.action_out_proj(suffix_out)
        v_t = v_t.to(dtype=torch.float32)
        return v_t