opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/policies/pi05/modeling_pi05.py (new file)
@@ -0,0 +1,1257 @@

#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""π05: A Vision-Language-Action Flow Model for General Robot Control

[Paper](https://www.physicalintelligence.company/download/pi05.pdf)
"""

import builtins
import logging
import math
from collections import deque
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
from einops import rearrange
from torch import Tensor, nn
from transformers import AutoProcessor, AutoTokenizer

from opentau.configs.policies import PreTrainedConfig
from opentau.configs.types import NormalizationMode
from opentau.policies.normalize import Normalize, Unnormalize
from opentau.policies.pi05.configuration_pi05 import PI05Config
from opentau.policies.pi05.paligemma_with_expert import (
    PaliGemmaWithExpertConfig,
    PaliGemmaWithExpertModel,
)
from opentau.policies.pretrained import PreTrainedPolicy, T
from opentau.utils.utils import get_safe_dtype


def create_sinusoidal_pos_embedding(
    time: Tensor, dimension: int, min_period: float, max_period: float, device: torch.device | str = "cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: A 1-D tensor of shape (batch_size,).
        dimension: The dimension of the embedding vectors. Must be divisible by 2.
        min_period: The minimum period of the sinusoidal functions.
        max_period: The maximum period of the sinusoidal functions.
        device: The device to create the tensors on. Defaults to "cpu".

    Returns:
        A tensor of shape (batch_size, dimension) containing the positional embeddings.

    Raises:
        ValueError: If dimension is not divisible by 2 or if time tensor is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    dtype = (
        get_safe_dtype(torch.float64, device.type)
        if isinstance(device, torch.device)
        else get_safe_dtype(torch.float64, device)
    )
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
    return pos_emb
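

# A minimal sketch (not part of the released file) exercising the helper above:
# two scalar timesteps each produce a row of concatenated sine/cosine features
# whose periods are log-spaced between min_period and max_period.
def _demo_sinusoidal_pos_embedding() -> None:
    """Illustrative shape check only; the function name is hypothetical."""
    t = torch.tensor([0.0, 0.5])
    emb = create_sinusoidal_pos_embedding(t, dimension=8, min_period=4e-3, max_period=4.0)
    assert emb.shape == (2, 8)  # (batch_size, dimension)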


def make_att_2d_masks(
    pad_masks: Tensor,
    att_masks: Tensor,
    n_cross_att_tokens: int | None = None,
    cross_att_pad_masks: Tensor | None = None,
) -> Tensor:
    """Creates a 2-D attention mask given padding and 1-D attention masks.

    Tokens can attend to valid input tokens whose cumulative `att_masks` value
    is smaller than or equal to theirs. This way `att_masks` int[B, N] can be
    used to set up several types of attention, for example:

    [[1 1 1 1 1 1]]: pure causal attention.

    [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
        themselves and the last 3 tokens have a causal attention. The first
        entry could also be a 1 without changing behaviour.

    [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
        block can attend all previous blocks and all tokens on the same block.

    Args:
        pad_masks: bool[B, N] true if it is part of the input, false if padding.
        att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
            it and 0 where it shares the same attention mask as the previous token.
        n_cross_att_tokens: Add attention mask for cross-attention tokens if
            `n_cross_att_tokens` is provided.
        cross_att_pad_masks: Padding masks for cross attention tokens. Required if
            `n_cross_att_tokens` is provided.

    Returns:
        A 2D attention mask tensor of shape (B, N, N + n_cross_att_tokens)
        if n_cross_att_tokens is provided, else (B, N, N).

    Raises:
        ValueError: If att_masks or pad_masks are not 2D (including batch dimension).
        AssertionError: If cross_att_pad_masks is missing when n_cross_att_tokens is set,
            or if its shape is incorrect.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    att_2d_masks = att_2d_masks & pad_2d_masks

    # If `n_cross_att_tokens` is provided, we add a mask for cross-attention tokens at the end of the sequence.
    if n_cross_att_tokens is not None:
        assert cross_att_pad_masks is not None, (
            "cross_att_pad_masks must be provided if n_cross_att_tokens is provided"
        )
        assert cross_att_pad_masks.shape == (att_masks.size(0), n_cross_att_tokens), (
            "cross_att_pad_masks must have shape (batch_size, n_cross_att_tokens)"
        )

        cross_att_mask = torch.full(
            (att_masks.size(0), att_masks.size(1), n_cross_att_tokens),
            True,
            dtype=torch.bool,
            device=att_masks.device,
        )

        # Apply padding masks: pad_masks for rows, cross_att_pad_masks for columns
        cross_att_mask = cross_att_mask & pad_masks[:, :, None] & cross_att_pad_masks[:, None, :]

        att_2d_masks = torch.cat((att_2d_masks, cross_att_mask), dim=2)

    return att_2d_masks
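

# A minimal sketch (not part of the released file) of the prefix-LM pattern
# described in the docstring above: the first two tokens attend to each other
# bidirectionally, the last two are causal.
def _demo_make_att_2d_masks() -> None:
    """Illustrative check only; the function name is hypothetical."""
    pad_masks = torch.ones(1, 4, dtype=torch.bool)
    att_masks = torch.tensor([[0, 0, 1, 1]])
    mask = make_att_2d_masks(pad_masks, att_masks)
    assert mask.shape == (1, 4, 4)
    assert bool(mask[0, 0, 1])  # token 0 sees token 1 (inside the prefix)
    assert not bool(mask[0, 2, 3])  # token 2 does not see the later token 3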


def resize_with_pad(img: Tensor, width: int, height: int, pad_value: int = -1) -> Tensor:
    """Resizes an image to fit within the specified dimensions while maintaining
    its aspect ratio, and pads the remaining area with the specified value.

    Args:
        img: Input image tensor of shape (batch_size, channels, current_height, current_width).
        width: Target width.
        height: Target height.
        pad_value: Value to use for padding. Defaults to -1.

    Returns:
        The resized and padded image tensor of shape (batch_size, channels, height, width).

    Raises:
        ValueError: If the input image tensor does not have 4 dimensions.
    """
    # assume no-op when the width/height already fit
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    cur_height, cur_width = img.shape[2:]

    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_img = F.interpolate(
        img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
    )

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))

    # pad on left and top of image
    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
    return padded_img
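

# A minimal sketch (not part of the released file): a 64x128 batch fit into
# 224x224. The aspect ratio is preserved (128 -> 224, 64 -> 112) and the
# left/top margin is filled with pad_value.
def _demo_resize_with_pad() -> None:
    """Illustrative shape check only; the function name is hypothetical."""
    img = torch.rand(2, 3, 64, 128)
    out = resize_with_pad(img, width=224, height=224, pad_value=0)
    assert out.shape == (2, 3, 224, 224)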


def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
    """Pads the last dimension of a vector to a new size with zeros.

    Args:
        vector: Input tensor. Can be (batch_size x sequence_length x features_dimension)
            or (batch_size x features_dimension).
        new_dim: The new size for the last dimension.

    Returns:
        The padded tensor.
    """
    if vector.shape[-1] == new_dim:
        return vector
    shape = list(vector.shape)
    current_dim = shape[-1]
    shape[-1] = new_dim
    new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
    new_vector[..., :current_dim] = vector
    return new_vector
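

# A minimal sketch (not part of the released file): zero-pad the feature
# dimension from 7 to 32, e.g. lifting a robot-specific action vector to a
# shared max_action_dim.
def _demo_pad_vector() -> None:
    """Illustrative check only; the function name is hypothetical."""
    v = torch.rand(2, 10, 7)
    out = pad_vector(v, 32)
    assert out.shape == (2, 10, 32)
    assert torch.equal(out[..., :7], v)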


def pad_discrete_tokens(tokens: list[list[int]], max_length: int) -> tuple[np.ndarray, np.ndarray]:
    """Pads or truncates a list of discrete action token sequences to a fixed length.

    Args:
        tokens: A list of discrete action token sequences (lists of integers).
        max_length: The target length for the discrete action token sequences.

    Returns:
        A tuple containing:
            - discrete_action_tokens: A numpy array of shape (len(tokens), max_length) containing the padded discrete action tokens.
            - discrete_action_masks: A boolean numpy array of shape (len(tokens), max_length) indicating valid discrete action tokens (True) and padding (False).
    """
    discrete_action_tokens = []
    discrete_action_masks = []
    for token in tokens:
        if len(token) > max_length:
            discrete_action_tokens.append(np.array(token[:max_length]))
            discrete_action_masks.append(np.ones(max_length, dtype=bool))
        else:
            discrete_action_masks.append(
                np.concatenate(
                    [np.ones(len(token), dtype=bool), np.zeros(max_length - len(token), dtype=bool)]
                )
            )
            discrete_action_tokens.append(np.pad(token, (0, max_length - len(token)), constant_values=0))
    return np.array(discrete_action_tokens), np.array(discrete_action_masks)
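

# A minimal sketch (not part of the released file): one short sequence is
# zero-padded, one long sequence is truncated, and the masks mark which
# positions hold real tokens.
def _demo_pad_discrete_tokens() -> None:
    """Illustrative check only; the function name is hypothetical."""
    toks, masks = pad_discrete_tokens([[5, 6], [1, 2, 3, 4, 9]], max_length=4)
    assert toks.shape == masks.shape == (2, 4)
    assert toks[0].tolist() == [5, 6, 0, 0] and masks[0].tolist() == [True, True, False, False]
    assert toks[1].tolist() == [1, 2, 3, 4] and masks[1].tolist() == [True] * 4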


class PI05Policy(PreTrainedPolicy):
    """Wrapper class around the PI05FlowMatching model to train and run inference within OpenTau."""

    config_class = PI05Config
    name = "pi05"

    def __init__(
        self,
        config: PI05Config,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """Initializes the PI05Policy.

        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """

        super().__init__(config)
        config.validate_features()
        self.config = config
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.normalize_actions = Normalize(
            config.output_features, {"ACTION": NormalizationMode.MIN_MAX}, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")

        self.discrete_action_processor = AutoProcessor.from_pretrained(
            "physical-intelligence/fast", trust_remote_code=True
        )
        # Get the vocab size from the processor
        discrete_action_vocab_size = getattr(self.discrete_action_processor, "vocab_size", None)
        self.model = PI05FlowMatching(config, discrete_action_vocab_size=discrete_action_vocab_size)

        self.reset()

    def reset(self) -> None:
        """This should be called whenever the environment is reset."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    @classmethod
    def from_pretrained(
        cls: builtins.type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = True,
        **kwargs,
    ) -> T:
        """Override the from_pretrained method to handle key remapping.

        Args:
            pretrained_name_or_path: Path to the pretrained model or its name on the Hub.
            config: Configuration object.
            force_download: Whether to force download the model weights.
            resume_download: Whether to resume download.
            proxies: Proxy configuration.
            token: Authentication token.
            cache_dir: Directory to cache downloaded files.
            local_files_only: Whether to only look for files locally.
            revision: Specific model revision.
            strict: Whether to strictly enforce state dict matching.
            **kwargs: Additional keyword arguments.

        Returns:
            The loaded model instance.

        Raises:
            ValueError: If pretrained_name_or_path is None.
        """
        if pretrained_name_or_path is None:
            raise ValueError("pretrained_name_or_path is required")

        # Use the provided config if available, otherwise create a default config
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )

        # Initialize the model without loading weights.
        # dataset_stats may have been provided in kwargs.
        model = cls(config, **kwargs)

        # Now manually load and remap the state dict
        try:
            # Try to load the pytorch_model.bin or model.safetensors file
            print(f"Loading model from: {pretrained_name_or_path}")
            try:
                from transformers.utils import cached_file

                # Try safetensors first
                resolved_file = cached_file(
                    pretrained_name_or_path,
                    "model.safetensors",
                    cache_dir=kwargs.get("cache_dir"),
                    force_download=kwargs.get("force_download", False),
                    resume_download=kwargs.get("resume_download"),
                    proxies=kwargs.get("proxies"),
                    use_auth_token=kwargs.get("use_auth_token"),
                    revision=kwargs.get("revision"),
                    local_files_only=kwargs.get("local_files_only", False),
                )
                from safetensors.torch import load_file

                original_state_dict = load_file(resolved_file)
                print("✓ Loaded state dict from model.safetensors")
            except Exception as e:
                print(f"Could not load state dict from remote files: {e}")
                print("Returning model without loading pretrained weights")
                return model

            # First, fix any key differences; see openpi `model.py`, `_fix_pytorch_state_dict_keys`
            fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

            # Then add a "model." prefix to all keys that don't already have it
            remapped_state_dict = {}
            remap_count = 0

            for key, value in fixed_state_dict.items():
                if not key.startswith("model.") and "normalize" not in key:
                    new_key = f"model.{key}"
                    remapped_state_dict[new_key] = value
                    remap_count += 1
                    if remap_count <= 10:  # Only print the first 10 to avoid spam
                        print(f"Remapped: {key} -> {new_key}")
                else:
                    remapped_state_dict[key] = value

            if remap_count > 0:
                print(f"Remapped {remap_count} state dict keys")

            # Load the remapped state dict into the model
            missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)

            if missing_keys:
                print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
                if len(missing_keys) <= 20:
                    for key in missing_keys:
                        print(f" - {key}")
                else:
                    for key in missing_keys[:20]:
                        print(f" - {key}")
                    print(f" ... and {len(missing_keys) - 20} more")

            if unexpected_keys:
                print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
                if len(unexpected_keys) <= 20:
                    for key in unexpected_keys:
                        print(f" - {key}")
                else:
                    for key in unexpected_keys[:20]:
                        print(f" - {key}")
                    print(f" ... and {len(unexpected_keys) - 20} more")

            if not missing_keys and not unexpected_keys:
                print("All keys loaded successfully!")

        except Exception as e:
            print(f"Warning: Could not remap state dict keys: {e}")

        return model

    def _fix_pytorch_state_dict_keys(
        self, state_dict: dict[str, Tensor], model_config: PreTrainedConfig
    ) -> dict[str, Tensor]:  # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
        """Fix state dict keys to match the current model architecture.

        Args:
            state_dict: The state dictionary to fix.
            model_config: The model configuration.

        Returns:
            The fixed state dictionary.
        """
        import re

        fixed_state_dict = {}

        for key, value in state_dict.items():
            new_key = key

            # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
            # for gemma expert layers
            if re.match(
                r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
                key,
            ):
                # Check if the model actually has adaRMS enabled for the expert
                expert_uses_adarms = getattr(
                    self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
                )
                if expert_uses_adarms:
                    logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
                    continue

            if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
                # Check if the model actually has adaRMS enabled for the expert
                expert_uses_adarms = getattr(
                    self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
                )
                if expert_uses_adarms:
                    logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
                    continue

            # Handle MLP naming changes for pi05:
            # the pi05 model expects time_mlp_*, but the checkpoint might have action_time_mlp_*
            if key.startswith("action_time_mlp_in."):
                new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
            elif key.startswith("action_time_mlp_out."):
                new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
            # Also handle state_proj, which shouldn't exist in pi05
            if key.startswith("state_proj."):
                logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
                continue

            # Handle potential differences in the vision tower embedding layer
            if "patch_embedding" in key:
                # Some checkpoints might have this, but the current model expects a different structure
                logging.warning(f"Vision embedding key might need handling: {key}")

            fixed_state_dict[new_key] = value

        return fixed_state_dict

    def get_optim_params(self) -> dict:
        """Returns the parameters to be optimized.

        Returns:
            A generator over the model parameters.
        """
        return self.parameters()

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk of actions given environment observations.

        Args:
            batch: Batch of data containing environment observations.

        Returns:
            The predicted action chunk.

        Raises:
            NotImplementedError: Always, as this method is not implemented for PI05.
        """
        raise NotImplementedError("Currently not implemented for PI05")

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `sample_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `sample_actions` when the
        queue is empty.

        Args:
            batch: Batch of data containing environment observations.
            noise: Optional noise tensor to be used during sampling.

        Returns:
            The selected action tensor.
        """
        self.eval()

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            actions = self.sample_actions(batch, noise=noise)
            self._action_queue.extend(actions)
        return self._action_queue.popleft()
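
    # A hedged usage sketch of the queue above (the rollout loop and `env` are
    # illustrative assumptions, not OpenTau API):
    #
    #     policy.reset()  # clear the queue whenever the environment resets
    #     for _ in range(episode_length):
    #         action = policy.select_action(obs_batch)  # pops one step; re-plans when empty
    #         obs_batch = env.step(action)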

    @torch.no_grad()
    def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Sample actions from the policy given environment observations.

        Args:
            batch: Batch of data containing environment observations.
            noise: Optional noise tensor.

        Returns:
            The sampled actions tensor of shape (n_action_steps, batch_size, action_dim).
        """
        batch = self.normalize_inputs(batch)

        images, img_masks = self.prepare_images(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)

        actions = self.model.sample_actions(
            images,
            img_masks,
            lang_tokens,
            lang_masks,
            noise=noise,
        )

        # Unpad actions
        original_action_dim = self.config.action_feature.shape[0]
        actions = actions[:, :, :original_action_dim]

        actions = self.unnormalize_outputs({"actions": actions})["actions"]

        # `self.model.sample_actions` returns a (batch_size, n_action_steps, action_dim) tensor, but the
        # queue effectively has shape (n_action_steps, batch_size, *), hence the transpose.
        actions = actions.transpose(0, 1)
        return actions

    def forward(
        self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None
    ) -> dict[str, Tensor]:
        """Do a full training forward pass to compute the loss.

        Args:
            batch: Batch of data containing environment observations, actions, and targets.
            noise: Optional noise tensor.
            time: Optional time tensor.

        Returns:
            A dictionary containing the loss components ("MSE" and "CE").
        """
        batch = self.normalize_inputs(batch)
        batch["discrete_actions"] = self.normalize_actions(dict(batch))["actions"]
        batch = self.normalize_targets(batch)

        # img_masks is True for real images and False for padded images
        images, img_masks = self.prepare_images(batch)
        # lang_masks is True for real tokens and False for padded tokens
        lang_tokens, lang_masks = self.prepare_language(batch)
        # discrete_action_masks is True for real tokens and False for padded tokens
        discrete_actions, discrete_action_masks = self.prepare_discrete_actions(batch)
        actions = batch["actions"]
        # actions_is_pad is False for real actions and True for padded actions
        actions_is_pad = batch.get("action_is_pad")

        losses = self.model.forward(
            images,
            img_masks,
            lang_tokens,
            lang_masks,
            actions,
            noise,
            time,
            discrete_actions,
            discrete_action_masks,
        )

        mse_loss = losses["MSE"]
        ce_loss = losses["CE"]
        if actions_is_pad is not None:
            in_episode_bound = ~actions_is_pad
            mse_loss = mse_loss * in_episode_bound.unsqueeze(-1)

        # Remove padding
        mse_loss = mse_loss[:, :, : self.config.max_action_dim]

        # For the backward pass
        loss = mse_loss.mean()

        return {"MSE": loss, "CE": ce_loss}

    def prepare_discrete_state(self, batch: dict[str, Tensor]) -> list[str]:
        """Discretizes the state into bins and converts it to a string representation.

        Each dimension of the state vector is discretized into 256 bins.
        The values of each dimension of the state are expected to be in the range [-1, 1];
        out-of-range values are clipped with a warning.
        The discretization bins are linearly spaced between -1 and 1.
        The bin indices for all dimensions are then concatenated into a space-separated string.

        Args:
            batch: Batch of data containing the "state" tensor.

        Returns:
            A list of strings, where each string is a space-separated list of discretized state values.
        """
        state = batch["state"]
        state_np = state.to(device="cpu", dtype=torch.float32).numpy()
        if np.any(state_np < -1.0) or np.any(state_np > 1.0):
            logging.warning(
                f"State values are not normalized between -1 and 1. Min: {state_np.min()}, Max: {state_np.max()}"
            )
            state_np = np.clip(state_np, -1.0, 1.0)
        discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
        return [
            " ".join(map(str, row)) for row in discretized_states
        ]  # TODO: return a tensor instead of a list of strings?
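
    # Illustrative sketch of the 256-bin mapping above (not part of the released
    # file): -1.0 lands in bin 0 and values just below 1.0 land in bin 255:
    #
    #     bins = np.linspace(-1, 1, 256 + 1)[:-1]
    #     np.digitize(np.array([-1.0, 0.0, 0.999]), bins=bins) - 1  # -> [0, 128, 255]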

    def prepare_discrete_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Prepares discrete actions for the model by tokenizing and padding them.

        Args:
            batch: Batch of data containing the key "discrete_actions".

        Returns:
            A tuple containing:
                - discrete_action_tokens: A tensor of shape (batch_size, max_length) containing the tokenized actions.
                - discrete_action_masks: A tensor of shape (batch_size, max_length) indicating valid tokens.
        """
        device = batch["discrete_actions"].device
        discrete_actions = batch["discrete_actions"].to(device="cpu", dtype=torch.float32)
        tokens = self.discrete_action_processor(discrete_actions)
        discrete_action_tokens, discrete_action_masks = pad_discrete_tokens(
            tokens, self.config.discrete_action_max_length
        )
        return torch.from_numpy(discrete_action_tokens).to(device=device, dtype=torch.long), torch.from_numpy(
            discrete_action_masks
        ).to(device=device, dtype=torch.bool)

    def prepare_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
        """Apply preprocessing to the images.

        Resizes to 224x224 with padding to keep the aspect ratio, and converts the pixel range
        from [0.0, 1.0] to [-1.0, 1.0] as expected by SigLIP.

        Args:
            batch: Batch of data containing image tensors.

        Returns:
            A tuple containing:
                - images: A list of processed image tensors.
                - img_masks: A list of image mask tensors.

        Raises:
            ValueError: If no image features are present in the batch.
        """
        images = []
        img_masks = []

        present_img_keys = [key for key in self.config.image_features if key in batch]
        missing_img_keys = [key for key in self.config.image_features if key not in batch]

        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
            )

        # Preprocess image features present in the batch
        for key in present_img_keys:
            img = batch[key]

            if self.config.resize_imgs_with_padding is not None:
                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)

            # Normalize from range [0,1] to [-1,1] as expected by SigLIP
            img = img * 2.0 - 1.0

            bsize = img.shape[0]
            device = img.device
            mask = torch.ones(bsize, dtype=torch.bool, device=device)
            images.append(img)
            img_masks.append(mask)

        # Create image features not present in the batch as fully padded images
        # (all pixels set to -1, i.e. black in the normalized range).
        for num_empty_cameras in range(len(missing_img_keys)):
            if num_empty_cameras >= self.config.empty_cameras:
                break
            img = torch.ones_like(img) * -1
            mask = torch.zeros_like(mask)
            images.append(img)
            img_masks.append(mask)

        return images, img_masks

    def prepare_language(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Tokenize the text input.

        The state is already expected to be discretized into a space-separated string.

        Args:
            batch: Batch of data containing the keys "prompt" and "state".

        Returns:
            A tuple containing:
                - lang_tokens: Tensor of language tokens.
                - lang_masks: Tensor of language attention masks.
        """
        device = batch["state"].device
        tasks = batch["prompt"]

        # The PaliGemma prompt has to end with a new line
        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

        # Add the discretized state to the prompt
        states = self.prepare_discrete_state(batch)
        prompt = [f"Task: {task}State: {state}\nActions:" for task, state in zip(tasks, states, strict=False)]

        tokenized_prompt = self.language_tokenizer(
            prompt,
            padding="max_length",
            padding_side="right",
            max_length=self.config.tokenizer_max_length,
            return_tensors="pt",
            truncation=True,
        )
        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks
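
    # Sketch of the resulting prompt layout (the task text and bin ids are
    # hypothetical):
    #
    #     "Task: pick up the cube\nState: 128 97 201 64\nActions:"
    #
    # i.e. the natural-language task, the space-separated state bins from
    # prepare_discrete_state, and an "Actions:" cue for the discrete action tokens.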


class PI05FlowMatching(nn.Module):
    """
    π05: A Vision-Language-Action Flow Model for General Robot Control

    [Paper](https://www.physicalintelligence.company/download/pi05.pdf)

    ┌──────────────────────────────────────────┐
    │               actions                    │
    │               ▲                          │
    │              ┌┴─────┐                    │
    │   kv cache   │Gemma │                    │
    │  ┌──────────►│Expert│                    │
    │  │           │      │                    │
    │ ┌┴─────────┐ │x 10  │                    │
    │ │          │ └▲─────┘                    │
    │ │PaliGemma │  │                          │
    │ │          │  noise                      │
    │ └▲──▲──▲──▲                              │
    │  │  │  │  └── discrete actions           │
    │  │  │  └───── robot state                │
    │  │  └──────── language tokens            │
    │  └─────────── image(s)                   │
    └──────────────────────────────────────────┘
    """

    def __init__(self, config: PI05Config, discrete_action_vocab_size: int | None = None):
        """Initializes the PI05FlowMatching model.

        Args:
            config: Model configuration.
            discrete_action_vocab_size: Size of the discrete action vocabulary.
        """
        super().__init__()
        self.config = config

        load_pretrained_paligemma = (
            self.config.init_strategy == "expert_only_he_init"
        )  # only load pretrained PaliGemma if we are He-initializing the expert only
        paligemma_with_expert_config = PaliGemmaWithExpertConfig(
            freeze_vision_encoder=self.config.freeze_vision_encoder,
            train_expert_only=self.config.train_expert_only,
            attention_implementation=self.config.attention_implementation,
            load_pretrained_paligemma=load_pretrained_paligemma,
            discrete_action_vocab_size=discrete_action_vocab_size,
            dropout=self.config.dropout,
        )
        self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_expert_config)

        # Projections are float32
        self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
        self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)

        self.time_mlp_in = nn.Linear(self.config.proj_width, self.config.proj_width)
        self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)

        self._init_model()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights using He (Kaiming) initialization.

        Args:
            module: The module to initialize.
        """
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _init_model(self) -> None:
        """Initialize the model weights based on the configuration."""
        if self.config.init_strategy == "no_init":
            return
        elif self.config.init_strategy == "full_he_init":
            for m in self.modules():
                self._init_weights(m)
        elif self.config.init_strategy == "expert_only_he_init":
            for m in self.paligemma_with_expert.gemma_expert.modules():
                self._init_weights(m)
        else:
            raise ValueError(f"Invalid init strategy: {self.config.init_strategy}")

    def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor:
        """Samples Gaussian noise.

        Args:
            shape: The shape of the noise tensor.
            device: The device to create the tensor on.

        Returns:
            A tensor containing the sampled noise.
        """
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )
        return noise

    def sample_time(self, bsize: int, device: torch.device | str) -> Tensor:
        """Samples time steps from a Beta distribution.

        Args:
            bsize: Batch size.
            device: The device to create the tensor on.

        Returns:
            A tensor containing the sampled time steps.
        """
        beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
        time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
        time = time_beta * 0.999 + 0.001
        return time
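
    # Sketch of the sampling distribution above (not part of the released file):
    # Beta(1.5, 1.0) has density proportional to sqrt(t), so training timesteps
    # are biased toward t = 1 (the noisy end of the flow), and the affine map
    # keeps them in [0.001, 1.0]:
    #
    #     t = torch.distributions.Beta(1.5, 1.0).sample((4,)) * 0.999 + 0.001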

    def embed_prefix(
        self,
        images: list[Tensor],
        img_masks: list[Tensor],
        lang_tokens: Tensor,
        lang_masks: Tensor,
        discrete_actions: Tensor | None = None,
        discrete_action_masks: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Embed images with SigLIP and language tokens with the embedding layer to prepare
        for PaliGemma transformer processing.

        Args:
            images: List of image tensors.
            img_masks: List of image mask tensors.
            lang_tokens: Language token tensor.
            lang_masks: Language mask tensor.
            discrete_actions: Optional discrete action tensor.
            discrete_action_masks: Optional discrete action mask tensor.

        Returns:
            A tuple containing:
                - embs: Concatenated embeddings tensor.
                - pad_masks: Concatenated padding masks tensor.
                - att_masks: Attention masks tensor.
        """
        # TODO: avoid lists in python and torch.cat; prefer pre-allocation with torch.empty
        embs = []
        pad_masks = []
        att_masks = []

        # TODO: remove for loop
        for (
            img,
            img_mask,
        ) in zip(images, img_masks, strict=False):
            img_emb = self.paligemma_with_expert.embed_image(img)
            img_emb = img_emb.to(dtype=torch.bfloat16)

            # Image embeddings don't need to be unnormalized because the `fix/lerobot_openpi` branch of
            # huggingface already removed the normalization inside PaliGemma.

            bsize, num_img_embs = img_emb.shape[:2]
            img_mask = img_mask[:, None].expand(bsize, num_img_embs)

            embs.append(img_emb)
            pad_masks.append(img_mask)

            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)

        # Normalize language embeddings
        lang_emb_dim = lang_emb.shape[-1]
        lang_emb = lang_emb * math.sqrt(lang_emb_dim)

        embs.append(lang_emb)
        pad_masks.append(lang_masks)

        # Full attention between image and language inputs
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        if discrete_actions is not None:
            discrete_action_emb = self.paligemma_with_expert.embed_discrete_actions(discrete_actions)
            embs.append(discrete_action_emb.to(dtype=torch.bfloat16))
            pad_masks.append(discrete_action_masks)
            att_masks += [1] * discrete_action_emb.shape[1]

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def embed_suffix(self, noisy_actions: Tensor, timestep: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Embed noisy_actions and timestep to prepare for Expert Gemma processing.

        Args:
            noisy_actions: Tensor containing noisy actions.
            timestep: Tensor containing timesteps.

        Returns:
            A tuple containing:
                - embs: Concatenated embeddings tensor.
                - pad_masks: Concatenated padding masks tensor.
                - att_masks: Attention masks tensor.
                - adarms_cond: AdaRMS conditioning tensor.
        """
        embs = []
        pad_masks = []
        att_masks = []

        bsize = noisy_actions.shape[0]
        dtype = torch.bfloat16
        device = noisy_actions.device

        # Embed the timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = create_sinusoidal_pos_embedding(
            timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
        )

        # Fuse timestep + action information using an MLP
        noisy_actions = noisy_actions.to(dtype=dtype)
        action_emb = self.action_in_proj(noisy_actions)

        def time_mlp_func(time_emb):
            x = self.time_mlp_in(time_emb)
            x = F.silu(x)
            x = self.time_mlp_out(x)
            return F.silu(x)

        time_emb = time_emb.to(dtype=dtype)
        adarms_cond = time_mlp_func(time_emb)

        # Add to input tokens
        embs.append(action_emb)

        bsize, action_dim = action_emb.shape[:2]
        action_mask = torch.ones(bsize, action_dim, dtype=torch.bool, device=device)
        pad_masks.append(action_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens
        att_masks += [1] + ([0] * (self.config.n_action_steps - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks, adarms_cond

    def forward(
        self,
        images: list[Tensor],
        img_masks: list[Tensor],
        lang_tokens: Tensor,
        lang_masks: Tensor,
        actions: Tensor,
        noise: Tensor | None = None,
        time: Tensor | None = None,
        discrete_actions: Tensor | None = None,
        discrete_action_masks: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Do a full training forward pass and compute the loss.

        Args:
            images: List of image tensors.
            img_masks: List of image mask tensors.
            lang_tokens: Language token tensor.
            lang_masks: Language mask tensor.
            actions: Action tensor.
            noise: Optional noise tensor.
            time: Optional time tensor.
            discrete_actions: Optional discrete action tensor.
            discrete_action_masks: Optional discrete action mask tensor.

        Returns:
            A dictionary containing the loss components ("MSE" and "CE").
        """
        # Run the VLM first to get the key-value cache
        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks, discrete_actions, discrete_action_masks
        )

        vlm_2d_attention_mask = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        vlm_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        num_cross_att_tokens = prefix_embs.shape[1] - self.config.discrete_action_max_length

        (prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
            attention_mask=vlm_2d_attention_mask,
            position_ids=vlm_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            n_cross_att_tokens=num_cross_att_tokens,
            use_cache=True,
            fill_kv_cache=True,
        )

        # Now run the action expert
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)

        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)

        action_expert_2d_attention_mask = make_att_2d_masks(
            suffix_pad_masks,
            suffix_att_masks,
            n_cross_att_tokens=num_cross_att_tokens,
            cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens],
        )
        # We should skip the response tokens when numbering the position ids for the action expert
        prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[
            :, None
        ]  # action expert position ids start after the prefix
        action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        # Stop gradients to avoid backpropagating from the action expert to the VLM
        for layer_idx in past_key_values:
            past_key_values[layer_idx]["key_states"] = past_key_values[layer_idx]["key_states"].detach()
            past_key_values[layer_idx]["value_states"] = past_key_values[layer_idx]["value_states"].detach()

        (_, suffix_out), _ = self.paligemma_with_expert.forward(
            attention_mask=action_expert_2d_attention_mask,
            position_ids=action_expert_position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=True,
            fill_kv_cache=False,
            adarms_cond=[None, adarms_cond],
        )

        # Compute the MSE loss for the velocity
        suffix_out = suffix_out[:, -self.config.n_action_steps :]
        # Original openpi code: upcast the attention output
        v_t = self.action_out_proj(suffix_out)
        v_t = v_t.to(dtype=torch.float32)

        losses = F.mse_loss(u_t, v_t, reduction="none")

        # Compute the cross-entropy loss for the discrete actions
        batch_size, seq_len = discrete_actions.shape
        discrete_action_out = prefix_out[:, -self.config.discrete_action_max_length - 1 : -1]
        logits = self.paligemma_with_expert.da_head(discrete_action_out)

        logits = logits.to(dtype=torch.float32)  # upcast to float32 for the loss calculation
        logits = rearrange(logits, "b s d -> (b s) d")
        labels = rearrange(discrete_actions, "b s -> (b s)")
        ce_loss = F.cross_entropy(logits, labels, reduction="none")

        ce_loss = rearrange(ce_loss, "(b s) -> b s", b=batch_size, s=seq_len)

        # Remove pad tokens
        discrete_action_is_pad = ~discrete_action_masks  # convert into a format where the value for pad is True
        ce_loss = ce_loss * ~discrete_action_is_pad

        # Compute the mean
        ce_loss = ce_loss.mean()

        return {"MSE": losses, "CE": ce_loss}
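
    # A note on the training target above (not part of the released file): the
    # interpolation follows the linear flow-matching path
    #
    #     x_t = t * noise + (1 - t) * actions
    #     u_t = d x_t / d t = noise - actions
    #
    # so the expert regresses the constant velocity u_t; integrating the
    # predicted velocity from t = 1 (pure noise) down to t = 0 in
    # sample_actions below recovers an action sample.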

    def sample_actions(
        self,
        images: list[Tensor],
        img_masks: list[Tensor],
        lang_tokens: Tensor,
        lang_masks: Tensor,
        noise: Tensor | None = None,
    ) -> Tensor:
        """Do a full inference forward pass and compute the action.

        Args:
            images: List of image tensors.
            img_masks: List of image mask tensors.
            lang_tokens: Language token tensor.
            lang_masks: Language mask tensor.
            noise: Optional noise tensor.

        Returns:
            The sampled action tensor.
        """
        bsize = lang_tokens.shape[0]
        device = lang_tokens.device

        if noise is None:
            actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
            noise = self.sample_noise(actions_shape, device)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        num_cross_att_tokens = prefix_embs.shape[1]

        # Compute the image and language key-value cache
        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            n_cross_att_tokens=num_cross_att_tokens,
            use_cache=self.config.use_cache,
            fill_kv_cache=True,
        )

        dt = -1.0 / self.config.num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                prefix_pad_masks,
                past_key_values,
                x_t,
                expanded_time,
            )

            # Euler step
            x_t += dt * v_t
            time += dt
        return x_t
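
    # A minimal sketch of the Euler integration pattern above on a toy 1-D flow
    # with a hand-written velocity field (illustrative, not OpenTau API):
    #
    #     num_steps = 10
    #     dt = -1.0 / num_steps
    #     x_t, time = torch.tensor([1.0]), 1.0
    #     while time >= -dt / 2:
    #         v_t = x_t  # stand-in for the learned velocity; with dt < 0 this shrinks x_t
    #         x_t = x_t + dt * v_t
    #         time += dt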

    def denoise_step(
        self,
        prefix_pad_masks: Tensor,
        past_key_values: list[dict[str, Tensor]],
        x_t: Tensor,
        timestep: Tensor,
    ) -> Tensor:
        """Apply one denoising step to the noise `x_t` at a given timestep.

        Args:
            prefix_pad_masks: Prefix padding masks.
            past_key_values: Past key values from the VLM.
            x_t: Current noise tensor.
            timestep: Current timestep.

        Returns:
            The predicted velocity tensor (v_t).
        """
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)

        num_cross_att_tokens = prefix_pad_masks.shape[1]
        action_expert_2d_attention_mask = make_att_2d_masks(
            suffix_pad_masks,
            suffix_att_masks,
            n_cross_att_tokens=num_cross_att_tokens,
            cross_att_pad_masks=prefix_pad_masks[:, :num_cross_att_tokens],
        )
        # We should skip the response tokens when numbering the position ids for the action expert
        prefix_offsets = torch.sum(prefix_pad_masks[:, : -self.config.discrete_action_max_length], dim=-1)[
            :, None
        ]  # action expert position ids start after the prefix
        action_expert_position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=action_expert_2d_attention_mask,
            position_ids=action_expert_position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=True,
            fill_kv_cache=False,
            adarms_cond=[None, adarms_cond],
        )
        suffix_out = outputs_embeds[1]
        suffix_out = suffix_out[:, -self.config.n_action_steps :]
        v_t = self.action_out_proj(suffix_out)
        v_t = v_t.to(dtype=torch.float32)
        return v_t